/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/shape_refiner.h"

#include <deque>
#include <limits>
#include <memory>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/eval_const_tensor.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;

ShapeRefiner::ShapeRefiner(int graph_def_version,
                           const OpRegistryInterface* ops)
    : graph_def_version_(graph_def_version),
      ops_registry_(ops),
      graph_runner_(Env::Default()) {}

ShapeRefiner::ShapeRefiner(const VersionDef& versions,
                           const OpRegistryInterface* ops)
    : ShapeRefiner(versions.producer(), ops) {}

ShapeRefiner::~ShapeRefiner() {
  // The lifetimes of the tensors are bound to the GraphRunner, so the tensors
  // should be deleted before it.
  const_tensor_map_.clear();
}

namespace {

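// Function instantiation represents a function's inputs as _Arg nodes and its
// outputs as _Retval nodes in the instantiated body graph; shape inference
// handles these two ops specially below.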
constexpr char kArgOp[] = "_Arg";
constexpr char kRetvalOp[] = "_Retval";

}  // namespace

// Runs shape inference for the given node using the given ShapeRefiner.
// The node must be a sub-node of a function node and the outer_context is
// the inference context of that function node in the outer graph.
Status ShapeRefiner::InferShapesForFunctionSubNode(
    const Node* node, InferenceContext* outer_context) {
  TF_RETURN_IF_ERROR(AddNodeInternal(node, outer_context));
  InferenceContext* node_context = CHECK_NOTNULL(GetContext(node));

  if (StringPiece(node->type_string()) == kArgOp) {
    // Handle special node: function input.
    // Shapes for these nodes are provided in the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_inputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid input index: ", index,
          " not in [0, ", outer_context->num_inputs(), ").");
    }

    // TODO(b/134547156): TEMPORARY WORKAROUND. If input shape handle is not set
    // in outer context, set _Arg node output shape to unknown.
    if (outer_context->input(index).SameHandle(ShapeHandle())) {
      VLOG(1) << "Function instantiation has undefined input shape at "
              << "index: " << index << " in the outer inference context.";
      node_context->set_output(0, node_context->UnknownShape());
    } else {
      node_context->set_output(0, outer_context->input(index));
    }

    auto* resource = outer_context->input_handle_shapes_and_types(index);
    if (resource) {
      node_context->set_output_handle_shapes_and_types(0, *resource);
    }
  } else if (StringPiece(node->type_string()) == kRetvalOp) {
    // Handle special node: function output.
    // Shapes inferred for these nodes go into the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_outputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid output index: ", index,
          " not in [0, ", outer_context->num_outputs(), ").");
    }

    // outer_context outlives node_context, therefore we need to create
    // a new shape handle owned by outer_context instead.
    ShapeHandle handle;
    TensorShapeProto proto;
    node_context->ShapeHandleToProto(node_context->input(0), &proto);
    TF_RETURN_IF_ERROR(outer_context->MakeShapeFromShapeProto(proto, &handle));
    outer_context->set_output(index, handle);

    auto* resource = node_context->input_handle_shapes_and_types(0);
    if (resource) {
      outer_context->set_output_handle_shapes_and_types(index, *resource);
    }
  }

  return Status::OK();
}

// TODO(cwhipkey): When an inference context inside a function has
// requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i)
// set when input(i) is an _Arg op, then this request should propagate to the
// outer context, and vice versa.
//
// NOTE: Recursive user-defined functions are not supported.
// Maybe we won't support recursive functions at all in TF, because of
// other maintainability issues.
Status ShapeRefiner::InferShapesForFunction(
    const FunctionDef* function_def, AttrSlice attributes,
    ExtendedInferenceContext* outer_context) {
  const Graph* graph;
  auto it = functions_.find(function_def);
  if (it != functions_.end()) {
    graph = it->second.get();
  } else {
    InstantiationResult result;
    TF_RETURN_IF_ERROR(InstantiateFunction(
        *function_def, attributes,
        [this](const string& op, const OpDef** sig) {
          return this->function_library_->LookUpOpDef(op, sig);
        },
        &result));

    Graph* new_graph = new Graph(function_library_);
    GraphConstructorOptions options;
    options.allow_internal_ops = true;
    TF_RETURN_IF_ERROR(
        ConvertNodeDefsToGraph(options, result.nodes, new_graph));
    functions_[function_def].reset(new_graph);
    graph = new_graph;
  }

  std::unordered_set<const Node*> function_nodes;
  Status inference_status = Status::OK();
  {
    auto node_shape_inference_lambda = [this, &outer_context, &function_nodes,
                                        &inference_status](const Node* node) {
      if (!inference_status.ok()) return;
      inference_status =
          InferShapesForFunctionSubNode(node, outer_context->get_context());
      function_nodes.insert(node);
    };

    // Calls the inference lambda for each node after visiting all of its
    // predecessors, ensuring that nodes are added to the ShapeRefiner in
    // topological order.
    ReverseDFS(*graph, {}, node_shape_inference_lambda);
  }

  // Delete the contexts created for the function nodes to save memory.
  for (const Node* node : function_nodes) {
    node_to_context_.erase(node);
  }

  return inference_status;
}

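// Adds a single node to the refiner and runs its shape function. Callers are
// expected to add nodes in an order where each node's data inputs have
// already been added (e.g. via a topological traversal of the graph); inputs
// that are missing are tolerated only as the back edges of v1 control-flow
// loops, where they are left unknown (see AddNodeInternal below). A minimal
// usage sketch, assuming a `graph` variable of type Graph:
//
//   ShapeRefiner refiner(graph.versions(), graph.op_registry());
//   for (const Node* n : /* nodes of `graph` in topological order */) {
//     TF_RETURN_IF_ERROR(refiner.AddNode(n));
//   }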
Status ShapeRefiner::AddNode(const Node* node) {
  return AddNodeInternal(node, /*outer_context=*/nullptr);
}

Status ShapeRefiner::AddNodeInternal(
    const Node* node, shape_inference::InferenceContext* outer_context) {
  // Create the inference context for this node with the existing input shapes.
  std::unique_ptr<InferenceContext> ic(new InferenceContext(
      graph_def_version_, node->def(), node->op_def(),
      std::vector<ShapeHandle>(node->num_inputs()), {}, {}, {}));
  TF_RETURN_IF_ERROR(ic->construction_status());

  // For each 'input' of this node, fetch the corresponding shape
  // from 'input's InferenceContext, and store into this node's
  // InferenceContext.
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    if (e->dst_input() < 0) {
      return tensorflow::errors::Internal(
          "Index ", e->dst_input(), " is negative but not a control edge.");
    }

    const Node* input = e->src();
    auto it = node_to_context_.find(input);
    if (it == node_to_context_.end()) {
      // v1 control flow adds loops to the graph; we have to break them
      // somewhere, so we'll ignore this input and leave its shape undefined.
      ic->SetInput(e->dst_input(), ic->UnknownShape());
      continue;
    }

    InferenceContext* input_ic = it->second->get_context();
    ic->SetInput(e->dst_input(), input_ic->output(e->src_output()));

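    // Resource and variant edges carry handles whose pointed-to values have
    // shapes and dtypes of their own; forward that metadata to this node's
    // context along with the (scalar) shape of the handle itself.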
    const auto* in_v =
        input_ic->output_handle_shapes_and_types(e->src_output());
    if (in_v != nullptr) {
      DataType input_type = e->src()->output_type(e->src_output());
      DCHECK(input_type == DT_RESOURCE || input_type == DT_VARIANT);
      ic->set_input_handle_shapes_and_types(e->dst_input(),
                                            std::vector<ShapeAndType>(*in_v));
    }
  }

  // Get the shape function for this node.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  std::unique_ptr<ExtendedInferenceContext> ec(
      new ExtendedInferenceContext(std::move(ic), node));

  // Run the shape inference function, and return if there was an error.
  TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get(), outer_context));

  // Store the resulting context object in the map.
  node_to_context_[node].swap(ec);

  return Status::OK();
}

Status ShapeRefiner::SetShape(const Node* node, int output_port,
                              ShapeHandle shape) {
  auto c = GetContext(node);
  if (c == nullptr) {
    return errors::Internal("Could not find context for ", node->name());
  }

  if (output_port < 0 || output_port >= node->num_outputs()) {
    return errors::InvalidArgument(
        "output_port '", output_port, "' is out of range, ", "node '",
        node->name(), "' has ", node->num_outputs(), " outputs");
  }
  // Note: it's possible, if the node's been updated, that the shape inference
  // context doesn't have the right number of outputs.
  if (node->num_outputs() > c->num_outputs()) {
    TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs()));
  }

  // Check compatibility, and merge the shapes.
  ShapeHandle existing_shape = c->output(output_port);
  TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &shape));
  c->set_output(output_port, shape);

  // TODO(vrv): Do we need to propagate the new shape through all
  // consumers that change their outputs?  At the moment, python
  // does not do this, but this seems like a nice feature.

  // TODO(vrv): We might need to keep track of the fact that the
  // existing shape is invalidated, in case we need to propagate
  // this information to remote workers.
  return Status::OK();
}

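// Re-runs shape inference for `node`, assuming its inputs may have changed.
// With relax == false, the new input shapes are merged with the existing ones
// (the result must be compatible with both and is at least as precise); with
// relax == true they are relaxed instead (the result generalizes both, which
// is what iterating loop-carried shapes to a fixed point requires). *refined
// is set to true if any input shape or handle data changed, or might have.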
Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
    *refined = true;
    return AddNode(node);
  }
  ExtendedInferenceContext* node_ext_context = it->second.get();
  InferenceContext* node_context = node_ext_context->get_context();

  // Give up if the context wasn't successfully built by the AddNode() method.
  TF_RETURN_IF_ERROR(node_context->construction_status());

  // Check if the shapes of the nodes in the fan-in of this node have changed,
  // and if they have, update the node's input shapes.
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    int dst_input = e->dst_input();
    int src_output = e->src_output();

    Node* input = e->src();
    auto iter = node_to_context_.find(input);
    if (iter == node_to_context_.end()) {
      return errors::FailedPrecondition(
          "Input ", dst_input, " ('", input->name(), "') for '", node->name(),
          "' was not previously added to ShapeRefiner.");
    }

    InferenceContext* c = iter->second->get_context();
    DCHECK_GE(dst_input, 0);
    ShapeHandle existing_input = node_context->input(dst_input);
    if (!relax) {
      if (node_context->MergeInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    } else {
      if (node_context->RelaxInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    }
    if (node_context->requested_input_tensor_as_partial_shape(dst_input)) {
      // The input value may have changed. Since we have no way to know if
      // that's indeed the case, err on the safe side.
      *refined = true;
    }

    // Also propagate handle shape and dtype of edges which are carrying
    // resource handles.
    if (e->src()->output_type(src_output) == DT_RESOURCE) {
      auto* outputs = c->output_handle_shapes_and_types(src_output);
      if (!outputs) continue;

      if (!relax &&
          node_context->MergeInputHandleShapesAndTypes(dst_input, *outputs)) {
        *refined = true;
      } else if (relax) {
        std::vector<ShapeAndType> existing_inputs;
        const std::vector<ShapeAndType>* inputs =
            node_context->input_handle_shapes_and_types(dst_input);
        if (inputs) {
          existing_inputs = *inputs;
        }
        if (node_context->RelaxInputHandleShapesAndMergeTypes(dst_input,
                                                              *outputs)) {
          if (IsUpdatedShapesOrTypes(
                  node_context, existing_inputs,
                  *node_context->input_handle_shapes_and_types(dst_input))) {
            *refined = true;
          }
        }
      }
    }
  }

  if (!*refined) {
    // No input shape has changed, we're done
    return Status::OK();
  }

  // Get and run the shape function for this node to update the shapes of the
  // outputs.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  if (!op_reg_data->shape_inference_fn) {
    // There is nothing more we can infer
    return Status::OK();
  }

  return RunShapeFn(node, op_reg_data, node_ext_context);
}

Status ShapeRefiner::EvaluateConstantTensorForEdge(
    const Node* node, int dst_idx, bool* evaluated, Tensor* result,
    InferenceContext* outer_context) {
  *evaluated = false;
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));
  OutputTensor tensor(input_edge->src(), input_edge->src_output());
  return EvaluateConstantTensor(
      tensor, *this, *ops_registry_, graph_def_version_, evaluated, result,
      &graph_runner_, &const_tensor_map_, kMaxTensorSize,
      disable_constant_propagation_, outer_context);
}

Status ShapeRefiner::EvaluateConstantIntScalarEdge(
    const Node* node, int dst_idx, bool* evaluated, int64* result,
    shape_inference::InferenceContext* outer_context) {
  Tensor scalar;
  TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, evaluated,
                                                   &scalar, outer_context));
  if (*evaluated) {
    if (scalar.NumElements() != 1) {
      return errors::InvalidArgument(
          "EvaluateConstantIntScalarEdge called on non-scalar edge: ",
          scalar.NumElements());
    }
    if (scalar.dtype() == DT_INT32) {
      *result = scalar.scalar<int32>()();
    } else {
      if (scalar.dtype() != DT_INT64) {
        return errors::InvalidArgument(
            "EvaluateConstantIntScalarEdge called on non-integer edge: ",
            scalar.dtype());
      }
      *result = scalar.scalar<int64>()();
    }
  }
  return Status::OK();
}

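// Interprets the tensor flowing into input `dst_idx` of `node` as a partial
// shape, without requiring it to be a fully evaluable constant. The op
// producing the tensor is pattern-matched (Shape, ShapeN, Pack, Concat,
// ConcatV2, StridedSlice, Cast, VariableShape) so that statically known
// dimensions can be recovered even when others are unknown; as a last resort
// the tensor is constant-folded and converted to a shape, with -1 denoting an
// unknown dimension.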
Status ShapeRefiner::ConstantPartialShape(
    InferenceContext* target_context, const Node* node, int dst_idx,
    ShapeHandle* result, shape_inference::InferenceContext* outer_context) {
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));

  InferenceContext* src_context = GetContext(input_edge->src());
  if (src_context == nullptr) return errors::Internal("Missing src context");
  ShapeHandle src_shape = src_context->output(input_edge->src_output());

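  // A rank-0 shape tensor cannot enumerate dimensions, so the only valid
  // static value for it is -1, which stands for a completely unknown shape.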
  if (src_context->Value(src_context->Rank(src_shape)) == 0) {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
                                                     &t, outer_context));
    if (!evaluated) {
      return errors::InvalidArgument(
          "Received a shape scalar with unknown static value.  A static value "
          "of '-1' is required to represent an unknown shape.");
    }
    if (t.dims() == 0) {
      if (t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) {
        *result = target_context->UnknownShape();
        return Status::OK();
      } else if (t.dtype() == DT_INT64 && t.scalar<int64>()() == -1) {
        *result = target_context->UnknownShape();
        return Status::OK();
      }
    }
    return errors::InvalidArgument(
        "Received an invalid shape scalar with a static value that is not "
        "'-1': ",
        t.DebugString());
  }

  TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));

  const string& src_op = input_edge->src()->type_string();
  if (src_context->Value(src_context->Dim(src_shape, 0)) == 0) {
    // Source tensor is a vector of length 0, so the shape it
    // represents is a scalar.
    *result = target_context->Scalar();
  } else if (src_op == "Cast") {
    // First try to evaluate the current tensor, as it might be a valid cast of
    // a float.
    Tensor t;
    bool evaluated = false;
    if (EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t,
                                      outer_context)
            .ok()) {
      if (evaluated &&
          target_context->MakeShapeFromTensor(&t, src_shape, result).ok()) {
        return Status::OK();
      }
    }

    // Then try to infer partial shape from the input to the cast tensor.
    ShapeHandle pre_cast_shape;
    if (!ConstantPartialShape(target_context, input_edge->src(), 0,
                              &pre_cast_shape, outer_context)
             .ok()) {
      TF_RETURN_IF_ERROR(
          target_context->MakeShapeFromTensor(nullptr, src_shape, result));
    }
    if (!target_context->RankKnown(pre_cast_shape)) {
      // Failed to evaluate. Treat the output as completely unknown.
      *result = target_context->UnknownShape();
      return Status::OK();
    }
    auto* dest_type = input_edge->src()->attrs().Find("DstT");
    if (dest_type == nullptr || dest_type->value_case() != AttrValue::kType ||
        (dest_type->type() != DT_INT32 && dest_type->type() != DT_INT64)) {
      // Casting to a weird type. Do not attempt to infer across it.
      *result = target_context->MakeShape(std::vector<DimensionHandle>(
          target_context->Rank(pre_cast_shape), target_context->UnknownDim()));
      return Status::OK();
    }
    *result = pre_cast_shape;
  } else if (src_op == "Shape") {
    *result = src_context->input(0);
  } else if (src_op == "ShapeN") {
    *result = src_context->input(input_edge->src_output());
  } else if (src_op == "Pack") {
    std::vector<DimensionHandle> dims;
    // Pack is concatenating its input scalars to form the shape tensor vector.
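    // For example, Pack(2, some_int32, 3) yields the partial shape [2, ?, 3]
    // when `some_int32` cannot be evaluated statically.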
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      int64 size;
      bool evaluated;
      TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(
          input_edge->src(), i, &evaluated, &size, outer_context));
      if (evaluated) {
        dims.push_back(size < 0 ? target_context->UnknownDim()
                                : target_context->MakeDim(size));
      } else {
        dims.push_back(target_context->UnknownDim());
      }
    }
    *result = target_context->MakeShape(dims);
  } else if (src_op == "Concat" || src_op == "ConcatV2") {
    *result = target_context->Scalar();
    // For Concat, input 0 is concat dim; for V2 it is the last input.
    const int concat_dim =
        src_op == "Concat" ? 0 : src_context->num_inputs() - 1;
    // Concat is concatenating its input shape vectors.
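    // For example, concatenating the partial shape vectors [2] and [?, 3]
    // yields the partial shape [2, ?, 3]: *result starts as the empty shape
    // and each input's partial shape is appended in turn.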
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      // Concat dim is ignored (and will always be a scalar).
      if (i == concat_dim) continue;
      ShapeHandle sub_result;
      TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(),
                                              i, &sub_result, outer_context));
      if (!target_context->RankKnown(sub_result)) {
        // Failed to evaluate. Treat the output as completely unknown.
        // TODO(cwhipkey): we could rely on all inputs being the same rank, so
        // figure that rank out and append the right number of unknown dims.
        *result = target_context->UnknownShape();
        return Status::OK();
      }
      TF_RETURN_IF_ERROR(
          target_context->Concatenate(*result, sub_result, result));
    }
  } else if (src_op == "StridedSlice") {
    TF_RETURN_IF_ERROR(PartialStridedSliceShape(input_edge->src(), src_context,
                                                result, outer_context));
  } else if (src_op == "VariableShape") {
    auto* handle_data = src_context->input_handle_shapes_and_types(0);
    if (handle_data != nullptr && !handle_data->empty()) {
      *result = handle_data->at(0).shape;
    } else {
      *result = target_context->UnknownShape();
    }
  } else {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(node, dst_idx, &evaluated,
                                                     &t, outer_context));
    TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
        evaluated ? &t : nullptr, src_shape, result));
  }
  return Status::OK();
}

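// Implements the ConstantPartialShape logic for a StridedSlice source node:
// the common idiom of slicing a shape vector, e.g. `tf.shape(t)[1:]`. Only
// single-element scalar begin/end/strides without special masks are handled
// (begin_mask and end_mask may be exactly 1, covering the one sliced
// dimension); anything else conservatively yields an unknown shape.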
Status ShapeRefiner::PartialStridedSliceShape(
    Node* slice_node, InferenceContext* ctx, ShapeHandle* result,
    shape_inference::InferenceContext* outer_context) {
  // Only attempt to evaluate if begin/end/strides all are scalars.
  for (int i = 1; i <= 3; ++i) {
    ShapeHandle input_shape = ctx->input(i);
    if (ctx->Value(ctx->Dim(input_shape, 0)) != 1) {
      *result = ctx->UnknownShape();
      return Status::OK();
    }
  }

  int begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "begin_mask", &begin_mask));
  TF_RETURN_IF_ERROR(GetNodeAttr(slice_node->attrs(), "end_mask", &end_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "ellipsis_mask", &ellipsis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "new_axis_mask", &new_axis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(slice_node->attrs(), "shrink_axis_mask", &shrink_axis_mask));

  // Only attempt to evaluate if there are no special masks set (note that we
  // can handle begin/end_mask == 1).
  if (!(begin_mask == 0 || begin_mask == 1) ||
      !(end_mask == 0 || end_mask == 1) || ellipsis_mask != 0 ||
      new_axis_mask != 0 || shrink_axis_mask != 0) {
    *result = ctx->UnknownShape();
    return Status::OK();
  }

  bool evaluated;
  int64 begin;
  if (begin_mask == 1) {
    begin = 0;
  } else {
    TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated,
                                                     &begin, outer_context));
    if (!evaluated) {
      *result = ctx->UnknownShape();
      return Status::OK();
    }
  }

  int64 end;
  if (end_mask == 1) {
    end = std::numeric_limits<int64>::max();
  } else {
    TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated,
                                                     &end, outer_context));
    if (!evaluated) {
      *result = ctx->UnknownShape();
      return Status::OK();
    }
  }

  int64 stride;
  TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated,
                                                   &stride, outer_context));
  if (!evaluated) {
    *result = ctx->UnknownShape();
    return Status::OK();
  }

  // Apply the slice's begin/end/stride to the input interpreted as a
  // partial shape.
  ShapeHandle input;
  TF_RETURN_IF_ERROR(
      ConstantPartialShape(ctx, slice_node, 0, &input, outer_context));
  TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result));
  return Status::OK();
}

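// Runs the registered shape function for `node` (or function-body shape
// inference for function call nodes) against the context held in `ec`. If the
// shape function requests the values of input tensors, this routine tries to
// materialize those inputs as constants (or as partial shapes, via
// ConstantPartialShape) and re-runs the shape function until no new
// information can be produced.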
Status ShapeRefiner::RunShapeFn(const Node* node,
                                const OpRegistrationData* op_reg_data,
                                ExtendedInferenceContext* ec,
                                InferenceContext* outer_context) {
  // This will be filled in with real data in a second pass.
  std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
  std::vector<Tensor> real_tensors(node->num_inputs());
  std::vector<bool> attempted_materialization(node->num_inputs());
  std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
  std::vector<ShapeHandle> input_tensors_as_shapes;

  auto* c = ec->get_context();

  c->set_input_tensors(input_tensors);
  c->set_input_tensors_as_shapes(input_tensors_as_shapes);

  // Run the shape inference function, and return if there was an error.
  // Capture as lambda, because we might need to re-run inference later on.
  auto run_inference_lambda = [&]() {
    if (function_library_ && IsFunctionCall(*function_library_, *node)) {
      bool disable_shape_inference;
      if (!GetNodeAttr(AttrSlice(node->def()), "_disable_call_shape_inference",
                       &disable_shape_inference)
               .ok() ||
          !disable_shape_inference) {
        // Special inference logic for user-defined functions.
        NameAttrList function;
        TF_RETURN_IF_ERROR(
            NameAndAttrsFromFunctionCall(node->def(), &function));
        const FunctionDef* function_def =
            function_library_->Find(function.name());
        if (function_def != nullptr) {
          // The constant Tensor map we have for the outside context is not
          // valid inside the function. We need to push a new clean map while
          // performing inference on the function body.
          auto const_tensor_map_copy = const_tensor_map_;
          const_tensor_map_.clear();
          Status function_inference_status = InferShapesForFunction(
              function_def, AttrSlice(&function.attr()), ec);
          const_tensor_map_ = const_tensor_map_copy;
          return function_inference_status;
        }
      }
    }

    if (op_reg_data->shape_inference_fn) {
      TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
    } else {
      TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
    }
    return Status::OK();
  };
  TF_RETURN_IF_ERROR(run_inference_lambda());

  // We must run the shape function repeatedly, in case users write
  // shape functions where they only conditionally call input_tensor()
  // based on the values of another input tensor.
  bool rerun_shape_fn;
  do {
    // If the result of running shape inference would have benefitted
    // from knowing the values of input tensors, try to materialize
    // the results of those tensors, and then run the shape inference
    // function again using those known tensors.
    rerun_shape_fn = false;

    // NOTE: It is possible to batch the extraction and
    // materialization of inputs, instead of materializing one input
    // at a time like we do below.  If input-at-a-time computation
    // becomes a bottleneck, we could separate ExtractConstantSubgraph
    // into two functions: one that returns true if an input is
    // derivable from constants, and another function that extracts
    // the subgraph for multiple target nodes and executes the whole
    // subgraph once.

    for (int i = 0; i < c->num_inputs(); ++i) {
      if (!c->requested_input_tensor(i)) {
        continue;
      }
      // Check if we have not already filled in the requested input,
      // and if not, try to materialize the tensors.
      if (!attempted_materialization[i]) {
        attempted_materialization[i] = true;

        Tensor result;
        bool evaluated = false;
        TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(
            node, i, &evaluated, &result, outer_context));
        if (evaluated) {
          real_tensors[i] = result;
          input_tensors[i] = &real_tensors[i];
          // We have more concrete information about a shape,
          // so re-run shape inference.
          rerun_shape_fn = true;
        }
      }
      if (c->requested_input_tensor_as_partial_shape(i) &&
          !attempted_tensor_as_shape_conversion[i]) {
        attempted_tensor_as_shape_conversion[i] = true;
        if (i >= input_tensors_as_shapes.size()) {
          input_tensors_as_shapes.resize(i + 1);
        }
        ShapeHandle s;
        TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s, outer_context));
        input_tensors_as_shapes[i] = s;
        rerun_shape_fn = true;
      }
    }

    if (rerun_shape_fn) {
      // We have more information about the shapes on this pass,
      // so re-run shape inference.
      c->set_input_tensors(input_tensors);
      c->set_input_tensors_as_shapes(input_tensors_as_shapes);
      TF_RETURN_IF_ERROR(run_inference_lambda());
    }
  } while (rerun_shape_fn);

  return Status::OK();
}

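// Returns true iff s0 and s1 denote the same known shape: either they are the
// same handle, or they have the same known rank and each dimension pair is
// either the same handle or has the same known, non-negative value. Shapes of
// unknown rank are never considered the same.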
bool ShapeRefiner::SameDefinedShape(InferenceContext* c, ShapeHandle s0,
                                    ShapeHandle s1) {
  if (s0.SameHandle(s1)) {
    return true;
  }
  if (c->Rank(s0) != c->Rank(s1)) {
    return false;
  }
  if (!c->RankKnown(s0) && !c->RankKnown(s1)) {
    return false;
  }
  for (int i = 0; i < c->Rank(s0); ++i) {
    if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
      int64 val0 = c->Value(c->Dim(s0, i));
      int64 val1 = c->Value(c->Dim(s1, i));
      if (val0 < 0 || val1 < 0 || val0 != val1) {
        return false;
      }
    }
  }

  return true;
}

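// Returns true iff `updated` differs from `existing` in any observable way:
// a different number of entries, a shape that is no longer the same defined
// shape, or a changed dtype.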
bool ShapeRefiner::IsUpdatedShapesOrTypes(
    InferenceContext* c, const std::vector<ShapeAndType>& existing,
    const std::vector<ShapeAndType>& updated) {
  if (existing.size() != updated.size()) {
    return true;
  }
  for (int i = 0; i < existing.size(); i++) {
    if (!SameDefinedShape(c, existing[i].shape, updated[i].shape) ||
        existing[i].dtype != updated[i].dtype) {
      return true;
    }
  }
  return false;
}

}  // namespace tensorflow