/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/shape_refiner.h"

#include <deque>
#include <memory>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;

ShapeRefiner::ShapeRefiner(int graph_def_version,
                           const OpRegistryInterface* ops)
    : graph_def_version_(graph_def_version),
      ops_registry_(ops),
      graph_runner_(Env::Default()) {}

ShapeRefiner::ShapeRefiner(const VersionDef& versions,
                           const OpRegistryInterface* ops)
    : ShapeRefiner(versions.producer(), ops) {}

ShapeRefiner::~ShapeRefiner() {
  // The lifetime of the tensors is bound to the GraphRunner, so the tensors
  // should be deleted before it.
  const_tensor_map_.clear();
}

namespace {

constexpr char kArgOp[] = "_Arg";
constexpr char kRetvalOp[] = "_Retval";

// Runs shape inference for the given node using the given ShapeRefiner.
// The node must be a sub-node of a function node and the outer_context is
// the inference context of that function node in the outer graph.
Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner,
                                     InferenceContext* outer_context) {
  TF_RETURN_IF_ERROR(refiner->AddNode(node));
  InferenceContext* node_context = CHECK_NOTNULL(refiner->GetContext(node));

  if (StringPiece(node->type_string()) == kArgOp) {
    // Handle special node: function input.
    // Shapes for these nodes are provided in the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_inputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid input index: ", index,
          " not in [0, ", outer_context->num_inputs(), ").");
    }

    node_context->set_output(0, outer_context->input(index));

    auto* resource = outer_context->input_handle_shapes_and_types(index);
    if (resource) {
      node_context->set_output_handle_shapes_and_types(0, *resource);
    }
  } else if (StringPiece(node->type_string()) == kRetvalOp) {
    // Handle special node: function output.
    // Shapes inferred for these nodes go into the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_outputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid output index: ", index,
          " not in [0, ", outer_context->num_outputs(), ").");
    }

    // outer_context outlives node_context, therefore we need to create
    // a new shape handle owned by outer_context instead.
    ShapeHandle handle;
    TensorShapeProto proto;
    node_context->ShapeHandleToProto(node_context->input(0), &proto);
    TF_RETURN_IF_ERROR(outer_context->MakeShapeFromShapeProto(proto, &handle));
    outer_context->set_output(index, handle);

    auto* resource = node_context->input_handle_shapes_and_types(0);
    if (resource) {
      outer_context->set_output_handle_shapes_and_types(index, *resource);
    }
  }

  return Status::OK();
}

}  // namespace

// TODO(cwhipkey): When an inference context inside function has
// requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i)
// set when input(i) is an _Arg op, then this request should propagate to
// context, and vice versa.
//
// NOTE: Recursive user-defined functions are not supported.
// Maybe we won't support recursive functions at all in TF, because of
// other maintainability issues.
Status ShapeRefiner::InferShapesForFunction(
    const tensorflow::FunctionDef* function_def, bool keep_nested_shapes,
    ExtendedInferenceContext* outer_context) {
  const Graph* graph;
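  // Instantiate the function body at most once: reuse a previously converted
  // graph from functions_ when this FunctionDef has been seen before.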
  auto it = functions_.find(function_def);
  if (it != functions_.end()) {
    graph = it->second.get();
  } else {
    InstantiationResult result;
    TF_RETURN_IF_ERROR(InstantiateFunction(
        *function_def, outer_context->get_context()->attrs(),
        [this](const string& op, const OpDef** sig) {
          return this->function_library_->LookUpOpDef(op, sig);
        },
        &result));

    Graph* new_graph = new Graph(function_library_);
    GraphConstructorOptions options;
    options.allow_internal_ops = true;
    TF_RETURN_IF_ERROR(
        ConvertNodeDefsToGraph(options, result.nodes, new_graph));
    functions_[function_def].reset(new_graph);
    graph = new_graph;
  }

  std::unordered_set<const Node*> function_nodes;
  Status inference_status = Status::OK();
  {
    auto node_shape_inference_lambda = [this, &outer_context, &function_nodes,
                                        &inference_status](const Node* node) {
      if (!inference_status.ok()) return;
      inference_status = InferShapesForFunctionSubNode(
          node, this, outer_context->get_context());
      function_nodes.insert(node);
    };

    // Calls inference lambda for each node after visiting all predecessors.
    // Ensures that we are adding nodes to ShapeRefiner in the topological
    // order.
    ReverseDFS(*graph, {}, node_shape_inference_lambda);
  }

  if (keep_nested_shapes && inference_status.ok()) {
    // Fill the nested inferences map.
    //
    // The materialized function graph has extra nodes for arguments and
    // return values, which are not explicitly listed in the FunctionDef.
    // We filter out these special nodes here so as not to expose
    // implementation details, and keep only inferences for the nodes listed
    // in the FunctionDef.
    std::unordered_map<string, const NodeDef*> user_defined_nodes;
    for (const auto& node_def : function_def->node_def()) {
      user_defined_nodes[node_def.name()] = &node_def;
    }

    std::unordered_map<string, std::unique_ptr<ExtendedInferenceContext>>
        nested_inferences;
    for (const Node* node : function_nodes) {
      const string& node_name = node->name();
      if (user_defined_nodes.find(node_name) != user_defined_nodes.end()) {
        nested_inferences[node_name] = std::move(node_to_context_[node]);
        node_to_context_.erase(node);
        // By default InferenceContext refers to a NodeDef from Graph.
        // Change it to the publicly accessible NodeDef of the function
        // definition.
        nested_inferences[node_name]->get_context()->node_def_ =
            user_defined_nodes[node_name];
      }
    }
    outer_context->set_nested_inferences(std::move(nested_inferences));
  } else {
    // Delete the contexts created for the function's nodes to save memory.
    for (const Node* node : function_nodes) {
      node_to_context_.erase(node);
    }
  }

  return inference_status;
}

Status ShapeRefiner::AddNode(const Node* node) {
  // For each 'input' of this node, fetch the corresponding shape
  // from 'input's InferenceContext, and store into a vector
  // indexed by 'node's input.
  std::vector<Node*> input_nodes(node->num_inputs());
  std::vector<ShapeHandle> input_shapes(node->num_inputs());
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      input_handle_shapes_and_types(node->num_inputs());
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    Node* input = e->src();
    auto it = node_to_context_.find(input);
    if (it == node_to_context_.end()) {
      return errors::FailedPrecondition(
          "Input ", e->dst_input(), " ('", input->name(), "') for '",
          node->name(), "' was not previously added to ShapeRefiner.");
    }

    InferenceContext* c = it->second->get_context();
    DCHECK_GE(e->dst_input(), 0);
    input_nodes[e->dst_input()] = input;
    input_shapes[e->dst_input()] = c->output(e->src_output());

    // Only propagate handle data of edges which are carrying resource handles.
    if (e->src()->output_type(e->src_output()) == DT_RESOURCE) {
      const auto* in_v = c->output_handle_shapes_and_types(e->src_output());
      if (in_v != nullptr) {
        input_handle_shapes_and_types[e->dst_input()].reset(
            new std::vector<ShapeAndType>(*in_v));
      }
    }
  }

  // Get the shape function for this node.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  // This needs to be filled in with real data in a second pass.
  std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
  std::vector<ShapeHandle> input_tensors_as_shapes;

  // Create the inference context for this node with the existing input shapes.
  std::unique_ptr<InferenceContext> c(
      new InferenceContext(graph_def_version_, &node->def(), node->op_def(),
                           input_shapes, input_tensors, input_tensors_as_shapes,
                           std::move(input_handle_shapes_and_types)));
  if (!c->construction_status().ok()) {
    return c->construction_status();
  }

  std::unique_ptr<ExtendedInferenceContext> ec(
      new ExtendedInferenceContext(std::move(c), node));

  // Run the shape inference function, and return if there was an error.
  TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get()));

  // Store the resulting context object in the map.
  node_to_context_[node].swap(ec);

  return Status::OK();
}

Status ShapeRefiner::SetShape(const Node* node, int output_port,
                              ShapeHandle shape) {
  auto c = GetContext(node);
  if (c == nullptr) {
    return errors::Internal("Could not find context for ", node->name());
  }

  if (output_port < 0 || output_port >= node->num_outputs()) {
    return errors::InvalidArgument(
        "output_port '", output_port, "' is out of range, ", "node '",
        node->name(), "' has ", node->num_outputs(), " outputs");
  }

  // Check compatibility, and merge the shapes.
  ShapeHandle existing_shape = c->output(output_port);
  TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &shape));
  c->set_output(output_port, shape);

  // TODO(vrv): Do we need to propagate the new shape through all
  // consumers that change their outputs?  At the moment, python
  // does not do this, but this seems like a nice feature.

  // TODO(vrv): We might need to keep track of the fact that the
  // existing shape is invalidated, in case we need to propagate
  // this information to remote workers.
  return Status::OK();
}

Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
    *refined = true;
    return AddNode(node);
  }
  ExtendedInferenceContext* node_ext_context = it->second.get();
  InferenceContext* node_context = node_ext_context->get_context();

  // Give up if the context wasn't successfully built by the AddNode() method.
  TF_RETURN_IF_ERROR(node_context->construction_status());

  // Check if the shapes of the nodes in the fan-in of this node have changed,
  // and if they have, update the node's input shapes.
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    int dst_input = e->dst_input();
    int src_output = e->src_output();

    Node* input = e->src();
    auto iter = node_to_context_.find(input);
    if (iter == node_to_context_.end()) {
      return errors::FailedPrecondition(
          "Input ", dst_input, " ('", input->name(), "') for '", node->name(),
          "' was not previously added to ShapeRefiner.");
    }

    InferenceContext* c = iter->second->get_context();
    DCHECK_GE(dst_input, 0);
    ShapeHandle existing_input = node_context->input(dst_input);
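    // Merging makes the stored input shape more specific, while relaxing
    // generalizes it toward a shape compatible with both; either way, record
    // whether the input shape actually changed.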
    if (!relax) {
      if (node_context->MergeInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    } else {
      if (node_context->RelaxInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    }

    // Also propagate handle shape and dtype of edges which are carrying
    // resource handles.
    if (e->src()->output_type(src_output) == DT_RESOURCE) {
      auto* outputs = c->output_handle_shapes_and_types(src_output);
      if (!outputs) continue;

      if (!relax &&
          node_context->MergeInputHandleShapesAndTypes(dst_input, *outputs)) {
        *refined = true;
      } else if (relax) {
        std::vector<ShapeAndType> existing_inputs;
        const std::vector<ShapeAndType>* inputs =
            node_context->input_handle_shapes_and_types(dst_input);
        if (inputs) {
          existing_inputs = *inputs;
        }
        if (node_context->RelaxInputHandleShapesAndMergeTypes(dst_input,
                                                              *outputs)) {
          if (IsUpdatedShapesOrTypes(
                  node_context, existing_inputs,
                  *node_context->input_handle_shapes_and_types(dst_input))) {
            *refined = true;
          }
        }
      }
    }
  }

  if (!*refined) {
    // No input shape has changed, we're done.
    return Status::OK();
  }

  // Get and run the shape function for this node to update the shapes of the
  // outputs.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  if (!op_reg_data->shape_inference_fn) {
    // There is nothing more we can infer.
    return Status::OK();
  }

  return RunShapeFn(node, op_reg_data, node_ext_context);
}

Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
                                                   int dst_idx, bool* evaluated,
                                                   Tensor* result) {
  *evaluated = false;

  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));

  // Simple case: the source node is a constant.
  const Node* src = input_edge->src();
  if (src->IsConstant()) {
    if (result->FromProto(src->def().attr().at("value").tensor())) {
      *evaluated = true;
      return Status::OK();
    }
  }

  if (disable_constant_propagation_) {
    return Status::OK();
  }

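  // Harder case: extract the subgraph feeding this edge and, if it turns out
  // to be constant, evaluate it with the GraphRunner below.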
  bool is_constant_graph = false;
  Graph subgraph(ops_registry_);
  auto versions = subgraph.versions();
  versions.set_producer(graph_def_version_);
  subgraph.set_versions(versions);

  // We identify the possibly constant subgraph to evaluate by
  // recursively iterating backwards through the inputs to 'node'
  // until we either 1) find an already existing input to our subgraph
  // (filled in `const_inputs`), 2) discover our graph is not constant,
  // or 3) hit a root node.
  std::vector<std::pair<string, Tensor>> const_inputs;
  TF_RETURN_IF_ERROR(ExtractConstantSubgraph(
      input_edge->src(), &subgraph, &is_constant_graph, &const_inputs));
  if (!is_constant_graph) {
    return Status::OK();
  }
  const string output_tensor_name =
      strings::StrCat(input_edge->src()->name(), ":", input_edge->src_output());
  std::vector<Tensor> outputs;

  // NOTE: we should pass in a function library runtime if we want
  // to support constant-expression evaluation on functions.
  Status s = graph_runner_.Run(&subgraph, nullptr /* function_library */,
                               const_inputs, {output_tensor_name}, &outputs);

  // If not all kernels in the constant graph are registered
  // in the process, GraphRunner::Run may fail, in which case
  // we cannot propagate constants, so this is best-effort.
  if (s.ok()) {
    *result = outputs[0];
    *evaluated = true;

    // We memoize (small) constants evaluated so far, so
    // ExtractConstantSubgraph can avoid extracting the full
    // subgraph.  As we build up large graphs, this avoids
    // repeated computation of the early parts of a constant
    // graph.
    if (outputs[0].TotalBytes() <= kMaxTensorSize) {
      const_tensor_map_[output_tensor_name] = outputs[0];
    }
  }
  return Status::OK();
}

Status ShapeRefiner::TryToInferTensorOutputFromInputShapes(const Edge* edge,
                                                           Tensor* output,
                                                           bool* success) {
  *success = false;
  const Node* node = edge->src();
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
    return errors::FailedPrecondition("Node does not have context.");
  }
  InferenceContext* c = it->second->get_context();

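  // Shape, Rank and Size are the only ops whose output values can be computed
  // directly from the shape metadata already inferred for their input.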
  if (node->type_string() == "Shape") {
    // If input shapes to the shape op are fully defined,
    // we can infer the shape op's output tensor.
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({input_rank}));
      if (node->output_type(0) == DT_INT32) {
        auto flat = t.flat<int>();
        for (int i = 0; i < input_rank; i++) {
          int64 dimension = c->Value(c->Dim(c->input(0), i));
          if (!FastBoundsCheck(dimension, std::numeric_limits<int32>::max())) {
            return errors::FailedPrecondition(
                "Shape has output type int32, but dimension exceeds maximum "
                "int32 value");
          }
          flat(i) = static_cast<int32>(dimension);
        }
      } else if (node->output_type(0) == DT_INT64) {
        auto flat = t.flat<int64>();
        for (int i = 0; i < input_rank; i++) {
          flat(i) = c->Value(c->Dim(c->input(0), i));
        }
      } else {
        return errors::FailedPrecondition(
            "Shape has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Rank") {
    bool rank_known = c->RankKnown(c->input(0));
    if (rank_known) {
      int32 input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      t.flat<int32>()(0) = input_rank;
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Size") {
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int32 rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      int64 size = 1;
      for (int i = 0; i < rank; i++) {
        size *= c->Value(c->Dim(c->input(0), i));
      }
      if (node->output_type(0) == DT_INT32) {
        if (!FastBoundsCheck(size, std::numeric_limits<int32>::max())) {
          return errors::FailedPrecondition(
              "Size has output type int32, but size exceeds maximum int32 "
              "value");
        }
        t.flat<int32>()(0) = static_cast<int32>(size);
      } else if (node->output_type(0) == DT_INT64) {
        t.flat<int64>()(0) = size;
      } else {
        return errors::FailedPrecondition(
            "Size has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  }
  return Status::OK();
}

Status ShapeRefiner::ExtractConstantSubgraph(
    Node* target_node, Graph* out_graph, bool* is_constant_graph,
    std::vector<std::pair<string, Tensor>>* const_inputs) {
  *is_constant_graph = false;
  std::unordered_set<string> const_inputs_added;

  if (target_node->op_def().is_stateful()) {
    return Status::OK();
  }

  if (target_node->type_string() == "PlaceholderWithDefault") {
    return Status::OK();
  }

  // TODO(skyewm): more of the filtering applied in input nodes below should be
  // applied to target_node here.

  struct NodeAndRecursed {
    Node* new_node = nullptr;
    bool recursed = false;
  };

  std::map<Node*, NodeAndRecursed> old_to_new_and_recursed;
  Node* target_node_copy = out_graph->CopyNode(target_node);
  old_to_new_and_recursed[target_node].new_node = target_node_copy;
  old_to_new_and_recursed[target_node].recursed = true;

  // Add the target node's inputs to seed the recursion.
  std::deque<const Edge*> edges_to_visit;
  for (const Edge* e : target_node->in_edges()) {
    // TODO(vrv): What do we do about control edges?  Based on our
    // definition of a constant graph, we should be free to ignore
    // control edges since the order in which a constant graph is
    // executed should be the same regardless of when nodes run: we
    // should only need to recurse down data edges.
    if (e->IsControlEdge()) continue;
    edges_to_visit.push_back(e);
  }

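  // Assume the subgraph is constant until one of the checks below proves
  // otherwise.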
  *is_constant_graph = true;

  // Iterate over the set of edges to visit (backwards).
  while (!edges_to_visit.empty()) {
    const Edge* current_edge = edges_to_visit.front();
    edges_to_visit.pop_front();
    Node* current_node = current_edge->src();

    // If the node is stateful, assume the graph is not constant.
    if (current_node->op_def().is_stateful()) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // During construction or import from GraphConstructor, back edges may not
    // be filled in.  Don't constant fold through merges at all for now.
    if (IsMerge(current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // Don't constant fold enter/exit currently either, as it's easy to end
    // up with a partial frame.
    if (IsEnter(current_node) || IsExit(current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // Placeholders should never be constant folded because their outputs are
    // fed by the user. Note that "Placeholder" nodes have no inputs so are
    // handled below.
    if (current_node->type_string() == "PlaceholderWithDefault") {
      *is_constant_graph = false;
      return Status::OK();
    }

    // If there is nothing more to recurse down, see if
    // the generator node is a constant.
    if (current_node->num_inputs() == 0) {
      if (!current_node->IsConstant()) {
        // Generator node is not a constant, so subgraph is not
        // constant.
        *is_constant_graph = false;
        return Status::OK();
      }
    }

    // Either the node is a constant, or the node is a potential
    // intermediate node on the path from a constant.
    //
    // Add a copy of its node and a new edge to the new subgraph.

    // Get or create the version of 'current_node' in the new graph.
    Node* current_node_copy;
    // This gets or creates the NodeAndRecursed entry for current_node.
    NodeAndRecursed* node_and_recursed = &old_to_new_and_recursed[current_node];
    if (node_and_recursed->new_node == nullptr) {
      // First time processing this node.
      current_node_copy = out_graph->CopyNode(current_node);
      // Track the mapping from the original node to the new one.
      node_and_recursed->new_node = current_node_copy;
    } else {
      current_node_copy = node_and_recursed->new_node;
    }

    // Add the edge to the destination node.
    {
      auto it = old_to_new_and_recursed.find(current_edge->dst());
      if (it == old_to_new_and_recursed.end()) {
        return errors::Internal(
            "Could not find mapping from old to new copy of destination node: ",
            current_edge->dst()->name());
      }
      Node* dst_copy = it->second.new_node;

      out_graph->AddEdge(current_node_copy, current_edge->src_output(),
                         dst_copy, current_edge->dst_input());
    }

    const string& output_tensor_name =
        strings::StrCat(current_node->name(), ":", current_edge->src_output());

    // Some tensor values can be inferred. For example, a shape op
    // with input shapes fully defined can have its output tensor inferred.
    Tensor tensor_inferred;
    bool successfully_inferred_tensor = false;
    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromInputShapes(
        current_edge, &tensor_inferred, &successfully_inferred_tensor));
    if (successfully_inferred_tensor) {
      const_inputs->emplace_back(output_tensor_name, tensor_inferred);
      const_inputs_added.insert(output_tensor_name);
      continue;
    }

    // If we have a copy of the input tensor materialized already,
    // then add to the list of inputs to feed and do not recurse further.
    auto it = const_tensor_map_.find(output_tensor_name);
    if (it != const_tensor_map_.end() &&
        const_inputs_added.count(output_tensor_name) == 0) {
      const_inputs->emplace_back(output_tensor_name, it->second);
      const_inputs_added.insert(output_tensor_name);
      continue;
    }

    // If this node's inputs have not been processed already, do so now.
    if (!node_and_recursed->recursed) {
      node_and_recursed->recursed = true;
      for (const Edge* e : current_node->in_edges()) {
        if (e->IsControlEdge()) continue;
        edges_to_visit.push_back(e);
      }
    }
  }

  return Status::OK();
}

Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
                                          const Node* node, int dst_idx,
                                          ShapeHandle* result) {
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));

  InferenceContext* src_context = GetContext(input_edge->src());
  if (src_context == nullptr) return errors::Internal("Missing src context");
  ShapeHandle src_shape = src_context->output(input_edge->src_output());
  TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));

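  // Prefer to recover the partial shape symbolically from the op that
  // produces the shape tensor (Shape, ShapeN, Pack, Concat), so unknown
  // dimensions are preserved; otherwise fall back to evaluating the tensor
  // as a constant below.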
  const string& src_op = input_edge->src()->type_string();
  if (src_context->Value(src_context->Dim(src_shape, 0)) == 0) {
    // Source tensor is a vector of length 0, so the shape it
    // represents is a scalar.
    *result = target_context->Scalar();
  } else if (src_op == "Shape") {
    *result = src_context->input(0);
  } else if (src_op == "ShapeN") {
    *result = src_context->input(input_edge->src_output());
  } else if (src_op == "Pack") {
    std::vector<DimensionHandle> dims;
    // Pack is concatenating its input scalars to form the shape tensor vector.
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      Tensor scalar;
      bool evaluated = false;
      TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(input_edge->src(), i,
                                                       &evaluated, &scalar));
      if (evaluated) {
        int64 size;
        if (scalar.dtype() == DT_INT32) {
          size = scalar.scalar<int32>()();
        } else if (scalar.dtype() == DT_INT64) {
          size = scalar.scalar<int64>()();
        } else {
          return errors::InvalidArgument("Pack input must be int32 or int64");
        }
        dims.push_back(size < 0 ? target_context->UnknownDim()
                                : target_context->MakeDim(size));
      } else {
        dims.push_back(target_context->UnknownDim());
      }
    }
    *result = target_context->MakeShape(dims);
  } else if (src_op == "Concat" || src_op == "ConcatV2") {
    *result = target_context->Scalar();
    // For Concat, input 0 is concat dim; for V2 it is the last input.
    const int concat_dim =
        src_op == "Concat" ? 0 : src_context->num_inputs() - 1;
    // Concat is concatenating its input shape vectors.
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      // Concat dim is ignored (and will always be a scalar).
      if (i == concat_dim) continue;
      ShapeHandle sub_result;
      TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(),
                                              i, &sub_result));
      if (!target_context->RankKnown(sub_result)) {
        // Failed to evaluate. Treat the output as completely unknown.
        // TODO(cwhipkey): we could rely on all inputs being the same rank, so
        // figure that rank out and append the right number of unknown dims.
        *result = target_context->UnknownShape();
        return Status::OK();
      }
      TF_RETURN_IF_ERROR(
          target_context->Concatenate(*result, sub_result, result));
    }
  } else {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(
        EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
    TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
        evaluated ? &t : nullptr, src_shape, result));
  }
  return Status::OK();
}

Status ShapeRefiner::RunShapeFn(const Node* node,
                                const OpRegistrationData* op_reg_data,
                                ExtendedInferenceContext* ec) {
  // This will be filled in with real data in a second pass.
  std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
  std::vector<Tensor> real_tensors(node->num_inputs());
  std::vector<bool> attempted_materialization(node->num_inputs());
  std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
  std::vector<ShapeHandle> input_tensors_as_shapes;

  auto* c = ec->get_context();

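  // Seed the context with empty input tensors and shapes; they are
  // materialized on demand in the loop below when the shape function
  // requests them.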
  c->set_input_tensors(input_tensors);
  c->set_input_tensors_as_shapes(input_tensors_as_shapes);

  // Run the shape inference function, and return if there was an error.
  // Capture as lambda, because we might need to re-run inference later on.
  auto run_inference_lambda = [&]() {
    if (function_library_ && op_reg_data->is_function_op) {
      // Special inference logic for user-defined functions.

      auto* func_def = function_library_->Find(op_reg_data->op_def.name());
      if (func_def) {
        return InferShapesForFunction(func_def, keep_nested_shape_inferences_,
                                      ec);
      }
    }

    if (op_reg_data->shape_inference_fn) {
      TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
    } else {
      TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
    }
    return Status::OK();
  };
  TF_RETURN_IF_ERROR(run_inference_lambda());

  // We must run the shape function repeatedly, in case users write
  // shape functions where they only conditionally call input_tensor()
  // based on the values of another input tensor.
  bool rerun_shape_fn;
  do {
    // If the result of running shape inference would have benefitted
    // from knowing the values of input tensors, try to materialize
    // the results of those tensors, and then run the shape inference
    // function again using those known tensors.
    rerun_shape_fn = false;

    // NOTE: It is possible to batch the extraction and
    // materialization of inputs, instead of materializing one input
    // at a time like we do below.  If input-at-a-time computation
    // becomes a bottleneck, we could separate ExtractConstantSubgraph
    // into two functions: one that returns true if an input is
    // derivable from constants, and another function that extracts
    // the subgraph for multiple target nodes and executes the whole
    // subgraph once.

    for (int i = 0; i < c->num_inputs(); ++i) {
      if (!c->requested_input_tensor(i)) {
        continue;
      }
      // Check if we have not already filled in the requested input,
      // and if not, try to materialize the tensors.
      if (!attempted_materialization[i]) {
        attempted_materialization[i] = true;

        Tensor result;
        bool evaluated = false;
        TF_RETURN_IF_ERROR(
            EvaluateConstantTensorForEdge(node, i, &evaluated, &result));
        if (evaluated) {
          real_tensors[i] = result;
          input_tensors[i] = &real_tensors[i];
          // We have more concrete information about a shape,
          // so re-run shape inference.
          rerun_shape_fn = true;
        }
      }
      if (c->requested_input_tensor_as_partial_shape(i) &&
          !attempted_tensor_as_shape_conversion[i]) {
        attempted_tensor_as_shape_conversion[i] = true;
        if (i >= input_tensors_as_shapes.size()) {
          input_tensors_as_shapes.resize(i + 1);
        }
        ShapeHandle s;
        TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s));
        input_tensors_as_shapes[i] = s;
        rerun_shape_fn = true;
      }
    }

    if (rerun_shape_fn) {
      // We have more information about the shapes on this pass,
      // so re-run shape inference.
      c->set_input_tensors(input_tensors);
      c->set_input_tensors_as_shapes(input_tensors_as_shapes);
      TF_RETURN_IF_ERROR(run_inference_lambda());
    }
  } while (rerun_shape_fn);

  return Status::OK();
}

bool ShapeRefiner::SameDefinedShape(InferenceContext* c, ShapeHandle s0,
                                    ShapeHandle s1) {
  if (s0.SameHandle(s1)) {
    return true;
  }
  if (c->Rank(s0) != c->Rank(s1)) {
    return false;
  }
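  // Shapes of unknown rank only compare equal via the handle identity check
  // above; two distinct unknown-rank shapes are treated as different.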
  if (!c->RankKnown(s0) && !c->RankKnown(s1)) {
    return false;
  }
  for (int i = 0; i < c->Rank(s0); ++i) {
    if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
      int64 val0 = c->Value(c->Dim(s0, i));
      int64 val1 = c->Value(c->Dim(s1, i));
      if (val0 < 0 || val1 < 0 || val0 != val1) {
        return false;
      }
    }
  }

  return true;
}

bool ShapeRefiner::IsUpdatedShapesOrTypes(
    InferenceContext* c, const std::vector<ShapeAndType>& existing,
    const std::vector<ShapeAndType>& updated) {
  if (existing.size() != updated.size()) {
    return true;
  }
  for (int i = 0; i < existing.size(); i++) {
    if (!SameDefinedShape(c, existing[i].shape, updated[i].shape) ||
        existing[i].dtype != updated[i].dtype) {
      return true;
    }
  }
  return false;
}

}  // namespace tensorflow