1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/shape_refiner.h"
16 
17 #include <deque>
18 #include <memory>
19 #include <unordered_set>
20 #include <vector>
21 
22 #include "tensorflow/core/common_runtime/eval_const_tensor.h"
23 #include "tensorflow/core/framework/bounds_check.h"
24 #include "tensorflow/core/framework/common_shape_fns.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor.pb.h"
28 #include "tensorflow/core/framework/versions.pb.h"
29 #include "tensorflow/core/graph/algorithm.h"
30 #include "tensorflow/core/graph/graph_constructor.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/gtl/stl_util.h"
33 #include "tensorflow/core/public/session.h"
34 
35 namespace tensorflow {
36 
37 using shape_inference::DimensionHandle;
38 using shape_inference::InferenceContext;
39 using shape_inference::ShapeAndType;
40 using shape_inference::ShapeHandle;
41 
ShapeRefiner(int graph_def_version,const OpRegistryInterface * ops)42 ShapeRefiner::ShapeRefiner(int graph_def_version,
43                            const OpRegistryInterface* ops)
44     : graph_def_version_(graph_def_version),
45       ops_registry_(ops),
46       graph_runner_(Env::Default()) {}
47 
ShapeRefiner(const VersionDef & versions,const OpRegistryInterface * ops)48 ShapeRefiner::ShapeRefiner(const VersionDef& versions,
49                            const OpRegistryInterface* ops)
50     : ShapeRefiner(versions.producer(), ops) {}
51 
~ShapeRefiner()52 ShapeRefiner::~ShapeRefiner() {
53   // The lifetime of the tensors are bound to the GraphRunner, so the tensors
54   // should be deleted before it.
55   const_tensor_map_.clear();
56 }
57 
58 namespace {
59 
60 constexpr char kArgOp[] = "_Arg";
61 constexpr char kRetvalOp[] = "_Retval";
62 
63 // Runs shape inference for the given node using the given ShapeRefiner.
64 // The node must be a sub-node of a function node and the outer_context is
65 // the inference context of that function node in the outer graph.
InferShapesForFunctionSubNode(const Node * node,ShapeRefiner * refiner,InferenceContext * outer_context)66 Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner,
67                                      InferenceContext* outer_context) {
68   TF_RETURN_IF_ERROR(refiner->AddNode(node));
69   InferenceContext* node_context = CHECK_NOTNULL(refiner->GetContext(node));
70 
71   if (StringPiece(node->type_string()) == kArgOp) {
72     // Handle special node: function input.
73     // Shapes for these nodes are provided in the outer inference
74     // context.
75 
76     int index;
77     TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));
78 
79     if (index < 0 || outer_context->num_inputs() <= index) {
80       return errors::Internal(
81           "Function instantiation included invalid input index: ", index,
82           " not in [0, ", outer_context->num_inputs(), ").");
83     }
84 
85     node_context->set_output(0, outer_context->input(index));
86 
87     auto* resource = outer_context->input_handle_shapes_and_types(index);
88     if (resource) {
89       node_context->set_output_handle_shapes_and_types(0, *resource);
90     }
91   } else if (StringPiece(node->type_string()) == kRetvalOp) {
92     // Handle special node: function output.
93     // Shapes inferred for these nodes go into the outer inference
94     // context.
95 
96     int index;
97     TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));
98 
99     if (index < 0 || outer_context->num_outputs() <= index) {
100       return errors::Internal(
101           "Function instantiation included invalid output index: ", index,
102           " not in [0, ", outer_context->num_outputs(), ").");
103     }
104 
105     // outer_context outlives node_context, therefore we need to create
106     // a new shape handle owned by outer_context instead.
107     ShapeHandle handle;
108     TensorShapeProto proto;
109     node_context->ShapeHandleToProto(node_context->input(0), &proto);
110     TF_RETURN_IF_ERROR(outer_context->MakeShapeFromShapeProto(proto, &handle));
111     outer_context->set_output(index, handle);
112 
113     auto* resource = node_context->input_handle_shapes_and_types(0);
114     if (resource) {
115       outer_context->set_output_handle_shapes_and_types(index, *resource);
116     }
117   }
118 
119   return Status::OK();
120 }
121 
122 }  // namespace
123 
124 // TODO(cwhipkey): When an inference context inside function has
125 // requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i)
126 // set when input(i) is an _Arg op, then this request should propagate to
127 // context, and vice versa.
128 //
129 // NOTE: Recursive user-defined functions are not supported.
130 // Maybe we won't support recursive functions at all in TF, because of
131 // other maintainability issues.
InferShapesForFunction(const tensorflow::FunctionDef * function_def,bool keep_nested_shapes,ExtendedInferenceContext * outer_context)132 Status ShapeRefiner::InferShapesForFunction(
133     const tensorflow::FunctionDef* function_def, bool keep_nested_shapes,
134     ExtendedInferenceContext* outer_context) {
135   const Graph* graph;
136   auto it = functions_.find(function_def);
137   if (it != functions_.end()) {
138     graph = it->second.get();
139   } else {
140     InstantiationResult result;
141     TF_RETURN_IF_ERROR(InstantiateFunction(
142         *function_def, outer_context->get_context()->attrs(),
143         [this](const string& op, const OpDef** sig) {
144           return this->function_library_->LookUpOpDef(op, sig);
145         },
146         &result));
147 
148     Graph* new_graph = new Graph(function_library_);
149     GraphConstructorOptions options;
150     options.allow_internal_ops = true;
151     TF_RETURN_IF_ERROR(
152         ConvertNodeDefsToGraph(options, result.nodes, new_graph));
153     functions_[function_def].reset(new_graph);
154     graph = new_graph;
155   }
156 
157   std::unordered_set<const Node*> function_nodes;
158   Status inference_status = Status::OK();
159   {
160     auto node_shape_inference_lambda = [this, &outer_context, &function_nodes,
161                                         &inference_status](const Node* node) {
162       if (!inference_status.ok()) return;
163       inference_status = InferShapesForFunctionSubNode(
164           node, this, outer_context->get_context());
165       function_nodes.insert(node);
166     };
167 
168     // Calls inference lambda for each node after visiting all predecessors.
169     // Ensures that we are adding nodes to ShapeRefiner in the topological
170     // order.
171     ReverseDFS(*graph, {}, node_shape_inference_lambda);
172   }
173 
174   if (keep_nested_shapes && inference_status.ok()) {
175     // Fill the nested inferences map.
176     //
177     // The materialized function graph has extra nodes for arguments and
178     // return values, which are not explicitly listed in the FunctionDef,
179     // we filter out these special nodes here to not expose the implementation
180     // details and keep only inferences for the nodes listed in the FunctionDef.
181     std::unordered_map<string, const NodeDef*> user_defined_nodes;
182     for (const auto& node_def : function_def->node_def()) {
183       user_defined_nodes[node_def.name()] = &node_def;
184     }
185 
186     std::unordered_map<string, std::unique_ptr<ExtendedInferenceContext>>
187         nested_inferences;
188     for (const Node* node : function_nodes) {
189       const string& node_name = node->name();
190       if (user_defined_nodes.find(node_name) != user_defined_nodes.end()) {
191         nested_inferences[node_name] = std::move(node_to_context_[node]);
192         node_to_context_.erase(node);
193         // By default InferenceContext refers to a NodeDef from Graph.
194         // Change it to the publicly accessible NodeDef of the function
195         // definition.
196         nested_inferences[node_name]->get_context()->node_def_ =
197             user_defined_nodes[node_name];
198       }
199     }
200     outer_context->set_nested_inferences(std::move(nested_inferences));
201   } else {
202     // Delete the contexts created for the functions nodes to save memory.
203     for (const Node* node : function_nodes) {
204       node_to_context_.erase(node);
205     }
206   }
207 
208   return inference_status;
209 }
210 
AddNode(const Node * node)211 Status ShapeRefiner::AddNode(const Node* node) {
212   // For each 'input' of this node, fetch the corresponding shape
213   // from 'input's InferenceContext, and store into a vector
214   // indexed by 'node's input.
215   std::vector<const Node*> input_nodes(node->num_inputs());
216   std::vector<ShapeHandle> input_shapes(node->num_inputs());
217   std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
218       input_handle_shapes_and_types(node->num_inputs());
219   for (const Edge* e : node->in_edges()) {
220     if (e->IsControlEdge()) continue;
221 
222     const Node* input = e->src();
223     auto it = node_to_context_.find(input);
224     if (it == node_to_context_.end()) {
225       return errors::FailedPrecondition(
226           "Input ", e->dst_input(), " ('", input->name(), "') for '",
227           node->name(), "' was not previously added to ShapeRefiner.");
228     }
229 
230     InferenceContext* c = it->second->get_context();
231     DCHECK_GE(e->dst_input(), 0);
232     input_nodes[e->dst_input()] = input;
233     input_shapes[e->dst_input()] = c->output(e->src_output());
234 
235     const auto* in_v = c->output_handle_shapes_and_types(e->src_output());
236     if (in_v != nullptr) {
237       DataType input_type = e->src()->output_type(e->src_output());
238       DCHECK(input_type == DT_RESOURCE || input_type == DT_VARIANT);
239       input_handle_shapes_and_types[e->dst_input()].reset(
240           new std::vector<ShapeAndType>(*in_v));
241     }
242   }
243 
244   // Get the shape function for this node
245   const OpRegistrationData* op_reg_data;
246   TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
247   if (op_reg_data->shape_inference_fn == nullptr &&
248       require_shape_inference_fns_) {
249     return errors::InvalidArgument(
250         "No shape inference function exists for op '", node->type_string(),
251         "', did you forget to define it?");
252   }
253 
254   // This needs to be filled in with real data in a second pass.
255   std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
256   std::vector<ShapeHandle> input_tensors_as_shapes;
257 
258   // Create the inference context for this node with the existing input shapes.
259   std::unique_ptr<InferenceContext> c(
260       new InferenceContext(graph_def_version_, &node->def(), node->op_def(),
261                            input_shapes, input_tensors, input_tensors_as_shapes,
262                            std::move(input_handle_shapes_and_types)));
263   if (!c->construction_status().ok()) {
264     return c->construction_status();
265   }
266 
267   std::unique_ptr<ExtendedInferenceContext> ec(
268       new ExtendedInferenceContext(std::move(c), node));
269 
270   // Run the shape inference function, and return if there was an error.
271   TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get()));
272 
273   // Store the resulting context object in the map.
274   node_to_context_[node].swap(ec);
275 
276   return Status::OK();
277 }
278 
SetShape(const Node * node,int output_port,ShapeHandle shape)279 Status ShapeRefiner::SetShape(const Node* node, int output_port,
280                               ShapeHandle shape) {
281   auto c = GetContext(node);
282   if (c == nullptr) {
283     return errors::Internal("Could not find context for ", node->name());
284   }
285 
286   if (output_port < 0 || output_port >= node->num_outputs()) {
287     return errors::InvalidArgument(
288         "output_port '", output_port, "' is out of range, ", "node '",
289         node->name(), "' has ", node->num_outputs(), " outputs");
290   }
291   // Note: it's possible, if the node's been updated, that the shape inference
292   // context doesn't have the right number of outputs.
293   if (node->num_outputs() > c->num_outputs()) {
294     TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs()));
295   }
296 
297   // Check compatibility, and merge the shapes.
298   ShapeHandle existing_shape = c->output(output_port);
299   TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &shape));
300   c->set_output(output_port, shape);
301 
302   // TODO(vrv): Do we need to propagate the new shape through all
303   // consumers that change their outputs?  At the moment, python
304   // does not do this, but this seems like a nice feature.
305 
306   // TODO(vrv): We might need to keep track of the fact that the
307   // existing shape is invalidated, in case we need to propagate
308   // this information to remote workers.
309   return Status::OK();
310 }
311 
UpdateNode(const Node * node,bool relax,bool * refined)312 Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
313   auto it = node_to_context_.find(node);
314   if (it == node_to_context_.end()) {
315     *refined = true;
316     return AddNode(node);
317   }
318   ExtendedInferenceContext* node_ext_context = it->second.get();
319   InferenceContext* node_context = node_ext_context->get_context();
320 
321   // Give up if the context wasn't successfully built by the AddNode() method.
322   TF_RETURN_IF_ERROR(node_context->construction_status());
323 
324   // Check if the shapes of the nodes in the fan-in of this node have changed,
325   // and if they have update the node input shapes.
326   for (const Edge* e : node->in_edges()) {
327     if (e->IsControlEdge()) continue;
328 
329     int dst_input = e->dst_input();
330     int src_output = e->src_output();
331 
332     Node* input = e->src();
333     auto iter = node_to_context_.find(input);
334     if (iter == node_to_context_.end()) {
335       return errors::FailedPrecondition(
336           "Input ", dst_input, " ('", input->name(), "') for '", node->name(),
337           "' was not previously added to ShapeRefiner.");
338     }
339 
340     InferenceContext* c = iter->second->get_context();
341     DCHECK_GE(dst_input, 0);
342     ShapeHandle existing_input = node_context->input(dst_input);
343     if (!relax) {
344       if (node_context->MergeInput(dst_input, c->output(src_output))) {
345         if (!SameDefinedShape(node_context, node_context->input(dst_input),
346                               existing_input)) {
347           *refined = true;
348         }
349       }
350     } else {
351       if (node_context->RelaxInput(dst_input, c->output(src_output))) {
352         if (!SameDefinedShape(node_context, node_context->input(dst_input),
353                               existing_input)) {
354           *refined = true;
355         }
356       }
357     }
358     if (node_context->requested_input_tensor_as_partial_shape(dst_input)) {
359       // The input value may have changed. Since we have no way to know if
360       // that's indeed the case, err on the safe side.
361       *refined = true;
362     }
363 
364     // Also propagate handle shape and dtype of edges which are carrying
365     // resource handles.
366     if (e->src()->output_type(src_output) == DT_RESOURCE) {
367       auto* outputs = c->output_handle_shapes_and_types(src_output);
368       if (!outputs) continue;
369 
370       if (!relax &&
371           node_context->MergeInputHandleShapesAndTypes(dst_input, *outputs)) {
372         *refined = true;
373       } else if (relax) {
374         std::vector<ShapeAndType> existing_inputs;
375         const std::vector<ShapeAndType>* inputs =
376             node_context->input_handle_shapes_and_types(dst_input);
377         if (inputs) {
378           existing_inputs = *inputs;
379         }
380         if (node_context->RelaxInputHandleShapesAndMergeTypes(dst_input,
381                                                               *outputs)) {
382           if (IsUpdatedShapesOrTypes(
383                   node_context, existing_inputs,
384                   *node_context->input_handle_shapes_and_types(dst_input))) {
385             *refined = true;
386           }
387         }
388       }
389     }
390   }
391 
392   if (!*refined) {
393     // No input shape has changed, we're done
394     return Status::OK();
395   }
396 
397   // Get and run the shape function for this node to update the shapes of the
398   // outputs.
399   const OpRegistrationData* op_reg_data;
400   TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
401   if (op_reg_data->shape_inference_fn == nullptr &&
402       require_shape_inference_fns_) {
403     return errors::InvalidArgument(
404         "No shape inference function exists for op '", node->type_string(),
405         "', did you forget to define it?");
406   }
407 
408   if (!op_reg_data->shape_inference_fn) {
409     // There is nothing more we can infer
410     return Status::OK();
411   }
412 
413   return RunShapeFn(node, op_reg_data, node_ext_context);
414 }
415 
EvaluateConstantTensorForEdge(const Node * node,int dst_idx,bool * evaluated,Tensor * result)416 Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
417                                                    int dst_idx, bool* evaluated,
418                                                    Tensor* result) {
419   *evaluated = false;
420   const Edge* input_edge;
421   TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
422   OutputTensor tensor(input_edge->src(), input_edge->src_output());
423   return EvaluateConstantTensor(tensor, *this, *ops_registry_,
424                                 graph_def_version_, evaluated, result,
425                                 &graph_runner_, &const_tensor_map_,
426                                 kMaxTensorSize, disable_constant_propagation_);
427 }
428 
EvaluateConstantIntScalarEdge(const Node * node,int dst_idx,bool * evaluated,int64 * result)429 Status ShapeRefiner::EvaluateConstantIntScalarEdge(const Node* node,
430                                                    int dst_idx, bool* evaluated,
431                                                    int64* result) {
432   Tensor scalar;
433   TF_RETURN_IF_ERROR(
434       EvaluateConstantTensorForEdge(node, dst_idx, evaluated, &scalar));
435   if (*evaluated) {
436     DCHECK_EQ(scalar.NumElements(), 1)
437         << "EvaluateConstantIntScalarEdge called on non-scalar edge: "
438         << scalar.NumElements();
439     if (scalar.dtype() == DT_INT32) {
440       *result = scalar.scalar<int32>()();
441     } else {
442       DCHECK_EQ(scalar.dtype(), DT_INT64)
443           << "EvaluateConstantIntScalarEdge called on non-integer edge: "
444           << scalar.dtype();
445       *result = scalar.scalar<int64>()();
446     }
447   }
448   return Status::OK();
449 }
450 
ConstantPartialShape(InferenceContext * target_context,const Node * node,int dst_idx,ShapeHandle * result)451 Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
452                                           const Node* node, int dst_idx,
453                                           ShapeHandle* result) {
454   const Edge* input_edge;
455   TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
456 
457   InferenceContext* src_context = GetContext(input_edge->src());
458   if (src_context == nullptr) return errors::Internal("Missing src context");
459   ShapeHandle src_shape = src_context->output(input_edge->src_output());
460 
461   if (src_context->Value(src_context->Rank(src_shape)) == 0) {
462     Tensor t;
463     bool evaluated = false;
464     TF_RETURN_IF_ERROR(
465         EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
466     if (!evaluated) {
467       return errors::InvalidArgument(
468           "Received a shape scalar with unknown static value.  A static value "
469           "of '-1' is required to represent an unknown shape.");
470     }
471     if (t.dims() == 0) {
472       if (t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) {
473         *result = target_context->UnknownShape();
474         return Status::OK();
475       } else if (t.dtype() == DT_INT64 && t.scalar<int64>()() == -1) {
476         *result = target_context->UnknownShape();
477         return Status::OK();
478       }
479     }
480     return errors::InvalidArgument(
481         "Received an invalid shape scalar with a static value that is not "
482         "'-1': ",
483         t.DebugString());
484   }
485 
486   TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));
487 
488   const string& src_op = input_edge->src()->type_string();
489   if (src_context->Value(src_context->Dim(src_shape, 0)) == 0) {
490     // Source tensor is a vector of length 0, so the shape it
491     // represents is as scalar.
492     *result = target_context->Scalar();
493   } else if (src_op == "Shape") {
494     *result = src_context->input(0);
495   } else if (src_op == "ShapeN") {
496     *result = src_context->input(input_edge->src_output());
497   } else if (src_op == "Pack") {
498     std::vector<DimensionHandle> dims;
499     // Pack is concatenating its input scalars to form the shape tensor vector.
500     for (int i = 0; i < src_context->num_inputs(); ++i) {
501       int64 size;
502       bool evaluated;
503       TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(input_edge->src(), i,
504                                                        &evaluated, &size));
505       if (evaluated) {
506         dims.push_back(size < 0 ? target_context->UnknownDim()
507                                 : target_context->MakeDim(size));
508       } else {
509         dims.push_back(target_context->UnknownDim());
510       }
511     }
512     *result = target_context->MakeShape(dims);
513   } else if (src_op == "Concat" || src_op == "ConcatV2") {
514     *result = target_context->Scalar();
515     // For Concat, input 0 is concat dim; for V2 it is the last input.
516     const int concat_dim =
517         src_op == "Concat" ? 0 : src_context->num_inputs() - 1;
518     // Concat is concatenating its input shape vectors.
519     for (int i = 0; i < src_context->num_inputs(); ++i) {
520       // Concat dim is ignored (and will always be a scalar).
521       if (i == concat_dim) continue;
522       ShapeHandle sub_result;
523       TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(),
524                                               i, &sub_result));
525       if (!target_context->RankKnown(sub_result)) {
526         // Failed to evaluate. Treat the output as completely unknown.
527         // TODO(cwhipkey): we could rely on all inputs being the same rank, so
528         // figure that rank out and append the right number of unknown dims.
529         *result = target_context->UnknownShape();
530         return Status::OK();
531       }
532       TF_RETURN_IF_ERROR(
533           target_context->Concatenate(*result, sub_result, result));
534     }
535   } else if (src_op == "StridedSlice") {
536     TF_RETURN_IF_ERROR(
537         PartialStridedSliceShape(input_edge->src(), src_context, result));
538   } else {
539     Tensor t;
540     bool evaluated = false;
541     TF_RETURN_IF_ERROR(
542         EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
543     TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
544         evaluated ? &t : nullptr, src_shape, result));
545   }
546   return Status::OK();
547 }
548 
PartialStridedSliceShape(Node * slice_node,InferenceContext * ctx,ShapeHandle * result)549 Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
550                                               InferenceContext* ctx,
551                                               ShapeHandle* result) {
552   // Only attempt to evaluate if begin/end/strides all are scalars.
553   for (int i = 1; i <= 3; ++i) {
554     ShapeHandle input_shape = ctx->input(i);
555     if (ctx->Value(ctx->Dim(input_shape, 0)) != 1) {
556       *result = ctx->UnknownShape();
557       return Status::OK();
558     }
559   }
560 
561   int begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
562   TF_RETURN_IF_ERROR(
563       GetNodeAttr(slice_node->attrs(), "begin_mask", &begin_mask));
564   TF_RETURN_IF_ERROR(GetNodeAttr(slice_node->attrs(), "end_mask", &end_mask));
565   TF_RETURN_IF_ERROR(
566       GetNodeAttr(slice_node->attrs(), "ellipsis_mask", &ellipsis_mask));
567   TF_RETURN_IF_ERROR(
568       GetNodeAttr(slice_node->attrs(), "new_axis_mask", &new_axis_mask));
569   TF_RETURN_IF_ERROR(
570       GetNodeAttr(slice_node->attrs(), "shrink_axis_mask", &shrink_axis_mask));
571 
572   // Only attempt to evaluate if there are no special masks set (note that we
573   // can handle begin/end_mask == 1).
574   if (!(begin_mask == 0 || begin_mask == 1) ||
575       !(end_mask == 0 || end_mask == 1) || ellipsis_mask != 0 ||
576       new_axis_mask != 0 || shrink_axis_mask != 0) {
577     *result = ctx->UnknownShape();
578     return Status::OK();
579   }
580 
581   bool evaluated;
582   int64 begin;
583   if (begin_mask == 1) {
584     begin = 0;
585   } else {
586     TF_RETURN_IF_ERROR(
587         EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated, &begin));
588     if (!evaluated) {
589       *result = ctx->UnknownShape();
590       return Status::OK();
591     }
592   }
593 
594   int64 end;
595   if (end_mask == 1) {
596     end = std::numeric_limits<int64>::max();
597   } else {
598     TF_RETURN_IF_ERROR(
599         EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated, &end));
600     if (!evaluated) {
601       *result = ctx->UnknownShape();
602       return Status::OK();
603     }
604   }
605 
606   int64 stride;
607   TF_RETURN_IF_ERROR(
608       EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated, &stride));
609   if (!evaluated) {
610     *result = ctx->UnknownShape();
611     return Status::OK();
612   }
613 
614   // Apply stride to input interpreted as a partial shape.
615   ShapeHandle input;
616   TF_RETURN_IF_ERROR(ConstantPartialShape(ctx, slice_node, 0, &input));
617   TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result));
618   return Status::OK();
619 }
620 
RunShapeFn(const Node * node,const OpRegistrationData * op_reg_data,ExtendedInferenceContext * ec)621 Status ShapeRefiner::RunShapeFn(const Node* node,
622                                 const OpRegistrationData* op_reg_data,
623                                 ExtendedInferenceContext* ec) {
624   // This will be filled in with real data in a second pass.
625   std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
626   std::vector<Tensor> real_tensors(node->num_inputs());
627   std::vector<bool> attempted_materialization(node->num_inputs());
628   std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
629   std::vector<ShapeHandle> input_tensors_as_shapes;
630 
631   auto* c = ec->get_context();
632 
633   c->set_input_tensors(input_tensors);
634   c->set_input_tensors_as_shapes(input_tensors_as_shapes);
635 
636   // Run the shape inference function, and return if there was an error.
637   // Capture as lambda, because we might need to re-run inference later on.
638   auto run_inference_lambda = [&]() {
639     if (function_library_ && op_reg_data->is_function_op) {
640       // Special inference logic for user-defined functions.
641 
642       auto* func_def = function_library_->Find(op_reg_data->op_def.name());
643       if (func_def) {
644         return InferShapesForFunction(func_def, keep_nested_shape_inferences_,
645                                       ec);
646       }
647     }
648 
649     if (op_reg_data->shape_inference_fn) {
650       TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
651     } else {
652       TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
653     }
654     return Status::OK();
655   };
656   TF_RETURN_IF_ERROR(run_inference_lambda());
657 
658   // We must run the shape function repeatedly, in case users write
659   // shape functions where they only conditionally call input_tensor()
660   // based on the values of another input tensor.
661   bool rerun_shape_fn;
662   do {
663     // If the result of running shape inference would have benefitted
664     // from knowing the values of input tensors, try to materialize
665     // the results of those tensors, and then run the shape inference
666     // function again using those known tensors.
667     rerun_shape_fn = false;
668 
669     // NOTE: It is possible to batch the extraction and
670     // materialization of inputs, instead of materializing one input
671     // at a time like we do below.  If input-at-a-time computation
672     // becomes a bottleneck, we could separate ExtractConstantSubgraph
673     // into two functions: one that returns true if an input is
674     // derivable from constants, and another function that extracts
675     // the subgraph for multiple target nodes and executes the whole
676     // subgraph once.
677 
678     for (int i = 0; i < c->num_inputs(); ++i) {
679       if (!c->requested_input_tensor(i)) {
680         continue;
681       }
682       // Check if we have not already filled in the requested input,
683       // and if not, try to materialize the tensors.
684       if (!attempted_materialization[i]) {
685         attempted_materialization[i] = true;
686 
687         Tensor result;
688         bool evaluated = false;
689         TF_RETURN_IF_ERROR(
690             EvaluateConstantTensorForEdge(node, i, &evaluated, &result));
691         if (evaluated) {
692           real_tensors[i] = result;
693           input_tensors[i] = &real_tensors[i];
694           // We have more concrete information about a shape,
695           // so re-run shape inference.
696           rerun_shape_fn = true;
697         }
698       }
699       if (c->requested_input_tensor_as_partial_shape(i) &&
700           !attempted_tensor_as_shape_conversion[i]) {
701         attempted_tensor_as_shape_conversion[i] = true;
702         if (i >= input_tensors_as_shapes.size()) {
703           input_tensors_as_shapes.resize(i + 1);
704         }
705         ShapeHandle s;
706         TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s));
707         input_tensors_as_shapes[i] = s;
708         rerun_shape_fn = true;
709       }
710     }
711 
712     if (rerun_shape_fn) {
713       // We have more information about the shapes on this pass,
714       // so re-run shape inference.
715       c->set_input_tensors(input_tensors);
716       c->set_input_tensors_as_shapes(input_tensors_as_shapes);
717       TF_RETURN_IF_ERROR(run_inference_lambda());
718     }
719   } while (rerun_shape_fn);
720 
721   return Status::OK();
722 }
723 
SameDefinedShape(InferenceContext * c,ShapeHandle s0,ShapeHandle s1)724 bool ShapeRefiner::SameDefinedShape(InferenceContext* c, ShapeHandle s0,
725                                     ShapeHandle s1) {
726   if (s0.SameHandle(s1)) {
727     return true;
728   }
729   if (c->Rank(s0) != c->Rank(s1)) {
730     return false;
731   }
732   if (!c->RankKnown(s0) && !c->RankKnown(s1)) {
733     return false;
734   }
735   for (int i = 0; i < c->Rank(s0); ++i) {
736     if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
737       int64 val0 = c->Value(c->Dim(s0, i));
738       int64 val1 = c->Value(c->Dim(s1, i));
739       if (val0 < 0 || val1 < 0 || val0 != val1) {
740         return false;
741       }
742     }
743   }
744 
745   return true;
746 }
747 
IsUpdatedShapesOrTypes(InferenceContext * c,const std::vector<ShapeAndType> & existing,const std::vector<ShapeAndType> & updated)748 bool ShapeRefiner::IsUpdatedShapesOrTypes(
749     InferenceContext* c, const std::vector<ShapeAndType>& existing,
750     const std::vector<ShapeAndType>& updated) {
751   if (existing.size() != updated.size()) {
752     return true;
753   }
754   for (int i = 0; i < existing.size(); i++) {
755     if (!SameDefinedShape(c, existing[i].shape, updated[i].shape) ||
756         existing[i].dtype != updated[i].dtype) {
757       return true;
758     }
759   }
760   return false;
761 }
762 
763 }  // namespace tensorflow
764