/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/shape_refiner.h"

#include <deque>
#include <memory>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;

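// A minimal usage sketch (illustrative only; 'graph' and 'topo_order' are
// assumed to be provided by the caller, with 'topo_order' listing nodes so
// that every node appears after all of its inputs, since AddNode() requires
// a node's inputs to have been added first):
//
//   ShapeRefiner refiner(graph.versions(), OpRegistry::Global());
//   for (const Node* n : topo_order) {
//     TF_RETURN_IF_ERROR(refiner.AddNode(n));
//   }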
ShapeRefiner::ShapeRefiner(int graph_def_version,
                           const OpRegistryInterface* ops)
    : graph_def_version_(graph_def_version),
      ops_registry_(ops),
      graph_runner_(Env::Default()) {}

ShapeRefiner::ShapeRefiner(const VersionDef& versions,
                           const OpRegistryInterface* ops)
    : ShapeRefiner(versions.producer(), ops) {}

ShapeRefiner::~ShapeRefiner() {
  // The lifetime of the tensors is bound to the GraphRunner, so the tensors
  // should be deleted before it.
  const_tensor_map_.clear();
}

namespace {

constexpr char kArgOp[] = "_Arg";
constexpr char kRetvalOp[] = "_Retval";

// Runs shape inference for the given node using the given ShapeRefiner.
// The node must be a sub-node of a function node and the outer_context is
// the inference context of that function node in the outer graph.
Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner,
                                     InferenceContext* outer_context) {
  TF_RETURN_IF_ERROR(refiner->AddNode(node));
  InferenceContext* node_context = CHECK_NOTNULL(refiner->GetContext(node));

  if (StringPiece(node->type_string()) == kArgOp) {
    // Handle special node: function input.
    // Shapes for these nodes are provided in the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_inputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid input index: ", index,
          " not in [0, ", outer_context->num_inputs(), ").");
    }

    node_context->set_output(0, outer_context->input(index));

    auto* resource = outer_context->input_handle_shapes_and_types(index);
    if (resource) {
      node_context->set_output_handle_shapes_and_types(0, *resource);
    }
  } else if (StringPiece(node->type_string()) == kRetvalOp) {
    // Handle special node: function output.
    // Shapes inferred for these nodes go into the outer inference
    // context.

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", &index));

    if (index < 0 || outer_context->num_outputs() <= index) {
      return errors::Internal(
          "Function instantiation included invalid output index: ", index,
          " not in [0, ", outer_context->num_outputs(), ").");
    }

    // outer_context outlives node_context, therefore we need to create
    // a new shape handle owned by outer_context instead.
    ShapeHandle handle;
    TensorShapeProto proto;
    node_context->ShapeHandleToProto(node_context->input(0), &proto);
    TF_RETURN_IF_ERROR(outer_context->MakeShapeFromShapeProto(proto, &handle));
    outer_context->set_output(index, handle);

    auto* resource = node_context->input_handle_shapes_and_types(0);
    if (resource) {
      outer_context->set_output_handle_shapes_and_types(index, *resource);
    }
  }

  return Status::OK();
}

}  // namespace

// TODO(cwhipkey): When an inference context inside function has
// requested_input_tensor(i) or requested_input_tensor_as_partial_shape(i)
// set when input(i) is an _Arg op, then this request should propagate to
// context, and vice versa.
//
// NOTE: Recursive user-defined functions are not supported.
// Maybe we won't support recursive functions at all in TF, because of
// other maintainability issues.
Status ShapeRefiner::InferShapesForFunction(
    const tensorflow::FunctionDef* function_def, bool keep_nested_shapes,
    ExtendedInferenceContext* outer_context) {
  const Graph* graph;
  auto it = functions_.find(function_def);
  if (it != functions_.end()) {
    graph = it->second.get();
  } else {
    InstantiationResult result;
    TF_RETURN_IF_ERROR(InstantiateFunction(
        *function_def, outer_context->get_context()->attrs(),
        [this](const string& op, const OpDef** sig) {
          return this->function_library_->LookUpOpDef(op, sig);
        },
        &result));

    Graph* new_graph = new Graph(function_library_);
    GraphConstructorOptions options;
    options.allow_internal_ops = true;
    TF_RETURN_IF_ERROR(
        ConvertNodeDefsToGraph(options, result.nodes, new_graph));
    functions_[function_def].reset(new_graph);
    graph = new_graph;
  }

  std::unordered_set<const Node*> function_nodes;
  Status inference_status = Status::OK();
  {
    auto node_shape_inference_lambda = [this, &outer_context, &function_nodes,
                                        &inference_status](const Node* node) {
      if (!inference_status.ok()) return;
      inference_status = InferShapesForFunctionSubNode(
          node, this, outer_context->get_context());
      function_nodes.insert(node);
    };

    // Calls inference lambda for each node after visiting all predecessors.
    // Ensures that we are adding nodes to ShapeRefiner in the topological
    // order.
    ReverseDFS(*graph, {}, node_shape_inference_lambda);
  }

  if (keep_nested_shapes && inference_status.ok()) {
    // Fill the nested inferences map.
    //
    // The materialized function graph has extra nodes for arguments and
    // return values, which are not explicitly listed in the FunctionDef.
    // We filter out these special nodes here so as not to expose the
    // implementation details, and keep only inferences for the nodes listed
    // in the FunctionDef.
    std::unordered_map<string, const NodeDef*> user_defined_nodes;
    for (const auto& node_def : function_def->node_def()) {
      user_defined_nodes[node_def.name()] = &node_def;
    }

    std::unordered_map<string, std::unique_ptr<ExtendedInferenceContext>>
        nested_inferences;
    for (const Node* node : function_nodes) {
      const string& node_name = node->name();
      if (user_defined_nodes.find(node_name) != user_defined_nodes.end()) {
        nested_inferences[node_name] = std::move(node_to_context_[node]);
        node_to_context_.erase(node);
        // By default InferenceContext refers to a NodeDef from Graph.
        // Change it to the publicly accessible NodeDef of the function
        // definition.
        nested_inferences[node_name]->get_context()->node_def_ =
            user_defined_nodes[node_name];
      }
    }
    outer_context->set_nested_inferences(std::move(nested_inferences));
  } else {
    // Delete the contexts created for the function's nodes to save memory.
    for (const Node* node : function_nodes) {
      node_to_context_.erase(node);
    }
  }

  return inference_status;
}

Status ShapeRefiner::AddNode(const Node* node) {
  // For each 'input' of this node, fetch the corresponding shape
  // from 'input's InferenceContext, and store into a vector
  // indexed by 'node's input.
  std::vector<Node*> input_nodes(node->num_inputs());
  std::vector<ShapeHandle> input_shapes(node->num_inputs());
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
      input_handle_shapes_and_types(node->num_inputs());
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    Node* input = e->src();
    auto it = node_to_context_.find(input);
    if (it == node_to_context_.end()) {
      return errors::FailedPrecondition(
          "Input ", e->dst_input(), " ('", input->name(), "') for '",
          node->name(), "' was not previously added to ShapeRefiner.");
    }

    InferenceContext* c = it->second->get_context();
    DCHECK_GE(e->dst_input(), 0);
    input_nodes[e->dst_input()] = input;
    input_shapes[e->dst_input()] = c->output(e->src_output());

    // Only propagate handle data of edges which are carrying resource handles.
    if (e->src()->output_type(e->src_output()) == DT_RESOURCE) {
      const auto* in_v = c->output_handle_shapes_and_types(e->src_output());
      if (in_v != nullptr) {
        input_handle_shapes_and_types[e->dst_input()].reset(
            new std::vector<ShapeAndType>(*in_v));
      }
    }
  }

  // Get the shape function for this node.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  // This needs to be filled in with real data in a second pass.
  std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
  std::vector<ShapeHandle> input_tensors_as_shapes;

  // Create the inference context for this node with the existing input shapes.
  std::unique_ptr<InferenceContext> c(
      new InferenceContext(graph_def_version_, &node->def(), node->op_def(),
                           input_shapes, input_tensors, input_tensors_as_shapes,
                           std::move(input_handle_shapes_and_types)));
  if (!c->construction_status().ok()) {
    return c->construction_status();
  }

  std::unique_ptr<ExtendedInferenceContext> ec(
      new ExtendedInferenceContext(std::move(c), node));

  // Run the shape inference function, and return if there was an error.
  TF_RETURN_IF_ERROR(RunShapeFn(node, op_reg_data, ec.get()));

  // Store the resulting context object in the map.
  node_to_context_[node].swap(ec);

  return Status::OK();
}

Status ShapeRefiner::SetShape(const Node* node, int output_port,
                              ShapeHandle shape) {
  auto c = GetContext(node);
  if (c == nullptr) {
    return errors::Internal("Could not find context for ", node->name());
  }

  if (output_port < 0 || output_port >= node->num_outputs()) {
    return errors::InvalidArgument(
        "output_port '", output_port, "' is out of range, ", "node '",
        node->name(), "' has ", node->num_outputs(), " outputs");
  }

  // Check compatibility, and merge the shapes.
  ShapeHandle existing_shape = c->output(output_port);
  TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &shape));
  c->set_output(output_port, shape);

  // TODO(vrv): Do we need to propagate the new shape through all
  // consumers that change their outputs? At the moment, python
  // does not do this, but this seems like a nice feature.

  // TODO(vrv): We might need to keep track of the fact that the
  // existing shape is invalidated, in case we need to propagate
  // this information to remote workers.
  return Status::OK();
}

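// Re-runs shape inference for 'node' if the shapes of its non-control inputs
// have changed since the last run, setting '*refined' to true when any input
// shape or resource handle data was observably updated.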
Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) {
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
    *refined = true;
    return AddNode(node);
  }
  ExtendedInferenceContext* node_ext_context = it->second.get();
  InferenceContext* node_context = node_ext_context->get_context();

  // Give up if the context wasn't successfully built by the AddNode() method.
  TF_RETURN_IF_ERROR(node_context->construction_status());

  // Check if the shapes of the nodes in the fan-in of this node have changed,
  // and if they have, update the node's input shapes.
  for (const Edge* e : node->in_edges()) {
    if (e->IsControlEdge()) continue;

    int dst_input = e->dst_input();
    int src_output = e->src_output();

    Node* input = e->src();
    auto iter = node_to_context_.find(input);
    if (iter == node_to_context_.end()) {
      return errors::FailedPrecondition(
          "Input ", dst_input, " ('", input->name(), "') for '", node->name(),
          "' was not previously added to ShapeRefiner.");
    }

    InferenceContext* c = iter->second->get_context();
    DCHECK_GE(dst_input, 0);
    ShapeHandle existing_input = node_context->input(dst_input);
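    // MergeInput() tightens the stored input shape toward more specific
    // information, while RelaxInput() loosens it to cover both shapes; in
    // both cases we count it as a refinement only if the resulting shape is
    // observably different from the existing one.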
    if (!relax) {
      if (node_context->MergeInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    } else {
      if (node_context->RelaxInput(dst_input, c->output(src_output))) {
        if (!SameDefinedShape(node_context, node_context->input(dst_input),
                              existing_input)) {
          *refined = true;
        }
      }
    }

    // Also propagate handle shape and dtype of edges which are carrying
    // resource handles.
    if (e->src()->output_type(src_output) == DT_RESOURCE) {
      auto* outputs = c->output_handle_shapes_and_types(src_output);
      if (!outputs) continue;

      if (!relax &&
          node_context->MergeInputHandleShapesAndTypes(dst_input, *outputs)) {
        *refined = true;
      } else if (relax) {
        std::vector<ShapeAndType> existing_inputs;
        const std::vector<ShapeAndType>* inputs =
            node_context->input_handle_shapes_and_types(dst_input);
        if (inputs) {
          existing_inputs = *inputs;
        }
        if (node_context->RelaxInputHandleShapesAndMergeTypes(dst_input,
                                                              *outputs)) {
          if (IsUpdatedShapesOrTypes(
                  node_context, existing_inputs,
                  *node_context->input_handle_shapes_and_types(dst_input))) {
            *refined = true;
          }
        }
      }
    }
  }

  if (!*refined) {
    // No input shape has changed, we're done.
    return Status::OK();
  }

  // Get and run the shape function for this node to update the shapes of the
  // outputs.
  const OpRegistrationData* op_reg_data;
  TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data));
  if (op_reg_data->shape_inference_fn == nullptr &&
      require_shape_inference_fns_) {
    return errors::InvalidArgument(
        "No shape inference function exists for op '", node->type_string(),
        "', did you forget to define it?");
  }

  if (!op_reg_data->shape_inference_fn) {
    // There is nothing more we can infer.
    return Status::OK();
  }

  return RunShapeFn(node, op_reg_data, node_ext_context);
}

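// Attempts to produce the constant value of the tensor feeding input 'dst_idx'
// of 'node', either directly from a Const source or by extracting and running
// a constant subgraph; '*evaluated' reports whether '*result' was filled in.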
Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
                                                   int dst_idx, bool* evaluated,
                                                   Tensor* result) {
  *evaluated = false;

  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));

  // Simple case: the source node is a constant.
  const Node* src = input_edge->src();
  if (src->IsConstant()) {
    if (result->FromProto(src->def().attr().at("value").tensor())) {
      *evaluated = true;
      return Status::OK();
    }
  }

  if (disable_constant_propagation_) {
    return Status::OK();
  }

  bool is_constant_graph = false;
  Graph subgraph(ops_registry_);
  auto versions = subgraph.versions();
  versions.set_producer(graph_def_version_);
  subgraph.set_versions(versions);

  // We identify the possibly constant subgraph to evaluate by
  // recursively iterating backwards through the inputs to 'node'
  // until we either 1) find an already existing input to our subgraph
  // (filled in `const_inputs`), 2) discover our graph is not constant,
  // or 3) hit a root node.
  std::vector<std::pair<string, Tensor>> const_inputs;
  TF_RETURN_IF_ERROR(ExtractConstantSubgraph(
      input_edge->src(), &subgraph, &is_constant_graph, &const_inputs));
  if (!is_constant_graph) {
    return Status::OK();
  }
  const string output_tensor_name =
      strings::StrCat(input_edge->src()->name(), ":", input_edge->src_output());
  std::vector<Tensor> outputs;

  // NOTE: we should pass in a function library runtime if we want
  // to support constant-expression evaluation on functions.
  Status s = graph_runner_.Run(&subgraph, nullptr /* function_library */,
                               const_inputs, {output_tensor_name}, &outputs);

  // If some kernels in the constant graph are not registered
  // in the process, GraphRunner::Run may fail, in which case
  // we cannot propagate constants, so this is best-effort.
  if (s.ok()) {
    *result = outputs[0];
    *evaluated = true;

    // We memoize (small) constants evaluated so far, so
    // ExtractConstantSubgraph can avoid extracting the full
    // subgraph. As we build up large graphs, this avoids
    // repeated computation of the early parts of a constant
    // graph.
    if (outputs[0].TotalBytes() <= kMaxTensorSize) {
      const_tensor_map_[output_tensor_name] = outputs[0];
    }
  }
  return Status::OK();
}

Status ShapeRefiner::TryToInferTensorOutputFromInputShapes(const Edge* edge,
                                                           Tensor* output,
                                                           bool* success) {
  *success = false;
  const Node* node = edge->src();
  auto it = node_to_context_.find(node);
  if (it == node_to_context_.end()) {
    return errors::FailedPrecondition("Node does not have context.");
  }
  InferenceContext* c = it->second->get_context();

  if (node->type_string() == "Shape") {
    // If input shapes to the shape op are fully defined,
    // we can infer the shape op's output tensor.
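    // For example, an input with fully defined shape [2, 3] lets us emit the
    // constant output vector {2, 3} without running the op.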
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({input_rank}));
      if (node->output_type(0) == DT_INT32) {
        auto flat = t.flat<int>();
        for (int i = 0; i < input_rank; i++) {
          int64 dimension = c->Value(c->Dim(c->input(0), i));
          if (!FastBoundsCheck(dimension, std::numeric_limits<int32>::max())) {
            return errors::FailedPrecondition(
                "Shape has output type int32, but dimension exceeds maximum "
                "int32 value");
          }
          flat(i) = static_cast<int32>(dimension);
        }
      } else if (node->output_type(0) == DT_INT64) {
        auto flat = t.flat<int64>();
        for (int i = 0; i < input_rank; i++) {
          flat(i) = c->Value(c->Dim(c->input(0), i));
        }
      } else {
        return errors::FailedPrecondition(
            "Shape has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Rank") {
    bool rank_known = c->RankKnown(c->input(0));
    if (rank_known) {
      int32 input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      t.flat<int32>()(0) = input_rank;
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Size") {
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int32 rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      int64 size = 1;
      for (int i = 0; i < rank; i++) {
        size *= c->Value(c->Dim(c->input(0), i));
      }
      if (node->output_type(0) == DT_INT32) {
        if (!FastBoundsCheck(size, std::numeric_limits<int32>::max())) {
          return errors::FailedPrecondition(
              "Size has output type int32, but size exceeds maximum int32 "
              "value");
        }
        t.flat<int32>()(0) = static_cast<int32>(size);
      } else if (node->output_type(0) == DT_INT64) {
        t.flat<int64>()(0) = size;
      } else {
        return errors::FailedPrecondition(
            "Size has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  }
  return Status::OK();
}

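// Walks backwards from 'target_node', copying nodes into 'out_graph' until it
// can prove the reachable subgraph either is or is not entirely constant;
// already memoized or inferable tensors are returned through 'const_inputs'
// so they can be fed to the subgraph instead of recomputed.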
Status ShapeRefiner::ExtractConstantSubgraph(
    Node* target_node, Graph* out_graph, bool* is_constant_graph,
    std::vector<std::pair<string, Tensor>>* const_inputs) {
  *is_constant_graph = false;
  std::unordered_set<string> const_inputs_added;

  if (target_node->op_def().is_stateful()) {
    return Status::OK();
  }

  if (target_node->type_string() == "PlaceholderWithDefault") {
    return Status::OK();
  }

  // TODO(skyewm): more of the filtering applied to input nodes below should be
  // applied to target_node here.

  struct NodeAndRecursed {
    Node* new_node = nullptr;
    bool recursed = false;
  };

  std::map<Node*, NodeAndRecursed> old_to_new_and_recursed;
  Node* target_node_copy = out_graph->CopyNode(target_node);
  old_to_new_and_recursed[target_node].new_node = target_node_copy;
  old_to_new_and_recursed[target_node].recursed = true;

  // Add the target node's inputs to seed the recursion.
  std::deque<const Edge*> edges_to_visit;
  for (const Edge* e : target_node->in_edges()) {
    // TODO(vrv): What do we do about control edges? Based on our
    // definition of a constant graph, we should be free to ignore
    // control edges since the order in which a constant graph is
    // executed should be the same regardless of when nodes run: we
    // should only need to recurse down data edges.
    if (e->IsControlEdge()) continue;
    edges_to_visit.push_back(e);
  }

  *is_constant_graph = true;

  // Iterate over the set of edges to visit (backwards).
  while (!edges_to_visit.empty()) {
    const Edge* current_edge = edges_to_visit.front();
    edges_to_visit.pop_front();
    Node* current_node = current_edge->src();

    // If the node is stateful, assume the graph is not constant.
    if (current_node->op_def().is_stateful()) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // During construction or import from GraphConstructor, back edges may not
    // be filled in. Don't constant fold through merges at all for now.
    if (IsMerge(current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // Don't constant fold enter/exit currently either, as it's easy to end
    // up with a partial frame.
    if (IsEnter(current_node) || IsExit(current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // Placeholders should never be constant folded because their outputs are
    // fed by the user. Note that "Placeholder" nodes have no inputs so are
    // handled below.
    if (current_node->type_string() == "PlaceholderWithDefault") {
      *is_constant_graph = false;
      return Status::OK();
    }

    // If there is nothing more to recurse down, see if
    // the generator node is a constant.
    if (current_node->num_inputs() == 0) {
      if (!current_node->IsConstant()) {
        // Generator node is not a constant, so subgraph is not
        // constant.
        *is_constant_graph = false;
        return Status::OK();
      }
    }

    // Either the node is a constant, or the node is a potential
    // intermediate node on the path from a constant.
    //
    // Add a copy of its node and a new edge to the new subgraph.

    // Get or create the version of 'current_node' in the new graph.
    Node* current_node_copy;
    // This gets or creates the NodeAndRecursed entry for current_node.
    NodeAndRecursed* node_and_recursed = &old_to_new_and_recursed[current_node];
    if (node_and_recursed->new_node == nullptr) {
      // First time processing this node.
      current_node_copy = out_graph->CopyNode(current_node);
      // Track the mapping from the original node to the new one.
      node_and_recursed->new_node = current_node_copy;
    } else {
      current_node_copy = node_and_recursed->new_node;
    }

    // Add the edge to the destination node.
    {
      auto it = old_to_new_and_recursed.find(current_edge->dst());
      if (it == old_to_new_and_recursed.end()) {
        return errors::Internal(
            "Could not find mapping from old to new copy of destination node: ",
            current_edge->dst()->name());
      }
      Node* dst_copy = it->second.new_node;

      out_graph->AddEdge(current_node_copy, current_edge->src_output(),
                         dst_copy, current_edge->dst_input());
    }

    const string output_tensor_name =
        strings::StrCat(current_node->name(), ":", current_edge->src_output());

    // Some tensor values can be inferred. For example, a shape op
    // with input shapes fully defined can have its output tensor inferred.
    Tensor tensor_inferred;
    bool successfully_inferred_tensor = false;
    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromInputShapes(
        current_edge, &tensor_inferred, &successfully_inferred_tensor));
    if (successfully_inferred_tensor) {
      const_inputs->emplace_back(output_tensor_name, tensor_inferred);
      const_inputs_added.insert(output_tensor_name);
      continue;
    }

    // If we have a copy of the input tensor materialized already,
    // then add to the list of inputs to feed and do not recurse further.
    auto it = const_tensor_map_.find(output_tensor_name);
    if (it != const_tensor_map_.end() &&
        const_inputs_added.count(output_tensor_name) == 0) {
      const_inputs->emplace_back(output_tensor_name, it->second);
      const_inputs_added.insert(output_tensor_name);
      continue;
    }

    // If this node's inputs have not been processed already, do so now.
    if (!node_and_recursed->recursed) {
      node_and_recursed->recursed = true;
      for (const Edge* e : current_node->in_edges()) {
        if (e->IsControlEdge()) continue;
        edges_to_visit.push_back(e);
      }
    }
  }

  return Status::OK();
}

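// Attempts to interpret the tensor flowing into input 'dst_idx' of 'node' as a
// (possibly partial) shape, tracing through shape-producing ops such as Shape,
// ShapeN, Pack and Concat so that pieces which cannot be evaluated become
// unknown dims rather than failing outright.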
Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
                                          const Node* node, int dst_idx,
                                          ShapeHandle* result) {
  const Edge* input_edge;
  TF_RETURN_IF_ERROR(node->input_edge(dst_idx, &input_edge));

  InferenceContext* src_context = GetContext(input_edge->src());
  if (src_context == nullptr) return errors::Internal("Missing src context");
  ShapeHandle src_shape = src_context->output(input_edge->src_output());
  TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));

  const string& src_op = input_edge->src()->type_string();
  if (src_context->Value(src_context->Dim(src_shape, 0)) == 0) {
    // Source tensor is a vector of length 0, so the shape it
    // represents is a scalar.
    *result = target_context->Scalar();
  } else if (src_op == "Shape") {
    *result = src_context->input(0);
  } else if (src_op == "ShapeN") {
    *result = src_context->input(input_edge->src_output());
  } else if (src_op == "Pack") {
    std::vector<DimensionHandle> dims;
    // Pack is concatenating its input scalars to form the shape tensor vector.
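    // For example, packing the constant 2 with a value that cannot be
    // evaluated produces the partial shape [2, ?].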
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      Tensor scalar;
      bool evaluated = false;
      TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(input_edge->src(), i,
                                                       &evaluated, &scalar));
      if (evaluated) {
        int64 size;
        if (scalar.dtype() == DT_INT32) {
          size = scalar.scalar<int32>()();
        } else if (scalar.dtype() == DT_INT64) {
          size = scalar.scalar<int64>()();
        } else {
          return errors::InvalidArgument("Pack input must be int32 or int64");
        }
        dims.push_back(size < 0 ? target_context->UnknownDim()
                                : target_context->MakeDim(size));
      } else {
        dims.push_back(target_context->UnknownDim());
      }
    }
    *result = target_context->MakeShape(dims);
  } else if (src_op == "Concat" || src_op == "ConcatV2") {
    *result = target_context->Scalar();
    // For Concat, input 0 is the concat dim; for V2 it is the last input.
    const int concat_dim =
        src_op == "Concat" ? 0 : src_context->num_inputs() - 1;
    // Concat is concatenating its input shape vectors.
    for (int i = 0; i < src_context->num_inputs(); ++i) {
      // Concat dim is ignored (and will always be a scalar).
      if (i == concat_dim) continue;
      ShapeHandle sub_result;
      TF_RETURN_IF_ERROR(ConstantPartialShape(target_context,
                                              input_edge->src(), i,
                                              &sub_result));
      if (!target_context->RankKnown(sub_result)) {
        // Failed to evaluate. Treat the output as completely unknown.
        // TODO(cwhipkey): we could rely on all inputs being the same rank, so
        // figure that rank out and append the right number of unknown dims.
        *result = target_context->UnknownShape();
        return Status::OK();
      }
      TF_RETURN_IF_ERROR(
          target_context->Concatenate(*result, sub_result, result));
    }
  } else {
    Tensor t;
    bool evaluated = false;
    TF_RETURN_IF_ERROR(
        EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
    TF_RETURN_IF_ERROR(target_context->MakeShapeFromTensor(
        evaluated ? &t : nullptr, src_shape, result));
  }
  return Status::OK();
}

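// Runs the registered shape function for 'node' on 'ec', re-running it
// whenever a request for a constant input tensor or a tensor-as-partial-shape
// succeeds in materializing new information, until nothing more can be
// materialized.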
Status ShapeRefiner::RunShapeFn(const Node* node,
                                const OpRegistrationData* op_reg_data,
                                ExtendedInferenceContext* ec) {
  // This will be filled in with real data in a second pass.
  std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
  std::vector<Tensor> real_tensors(node->num_inputs());
  std::vector<bool> attempted_materialization(node->num_inputs());
  std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs());
  std::vector<ShapeHandle> input_tensors_as_shapes;

  auto* c = ec->get_context();

  c->set_input_tensors(input_tensors);
  c->set_input_tensors_as_shapes(input_tensors_as_shapes);

  // Run the shape inference function, and return if there was an error.
  // Capture as lambda, because we might need to re-run inference later on.
  auto run_inference_lambda = [&]() {
    if (function_library_ && op_reg_data->is_function_op) {
      // Special inference logic for user-defined functions.

      auto* func_def = function_library_->Find(op_reg_data->op_def.name());
      if (func_def) {
        return InferShapesForFunction(func_def, keep_nested_shape_inferences_,
                                      ec);
      }
    }

    if (op_reg_data->shape_inference_fn) {
      TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn));
    } else {
      TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape));
    }
    return Status::OK();
  };
  TF_RETURN_IF_ERROR(run_inference_lambda());

  // We must run the shape function repeatedly, in case users write
  // shape functions where they only conditionally call input_tensor()
  // based on the values of another input tensor.
  bool rerun_shape_fn;
  do {
    // If the result of running shape inference would have benefited
    // from knowing the values of input tensors, try to materialize
    // the results of those tensors, and then run the shape inference
    // function again using those known tensors.
    rerun_shape_fn = false;

    // NOTE: It is possible to batch the extraction and
    // materialization of inputs, instead of materializing one input
    // at a time like we do below. If input-at-a-time computation
    // becomes a bottleneck, we could separate ExtractConstantSubgraph
    // into two functions: one that returns true if an input is
    // derivable from constants, and another function that extracts
    // the subgraph for multiple target nodes and executes the whole
    // subgraph once.

    for (int i = 0; i < c->num_inputs(); ++i) {
      if (!c->requested_input_tensor(i)) {
        continue;
      }
      // Check if we have not already filled in the requested input,
      // and if not, try to materialize the tensors.
      if (!attempted_materialization[i]) {
        attempted_materialization[i] = true;

        Tensor result;
        bool evaluated = false;
        TF_RETURN_IF_ERROR(
            EvaluateConstantTensorForEdge(node, i, &evaluated, &result));
        if (evaluated) {
          real_tensors[i] = result;
          input_tensors[i] = &real_tensors[i];
          // We have more concrete information about a shape,
          // so re-run shape inference.
          rerun_shape_fn = true;
        }
      }
      if (c->requested_input_tensor_as_partial_shape(i) &&
          !attempted_tensor_as_shape_conversion[i]) {
        attempted_tensor_as_shape_conversion[i] = true;
        if (i >= input_tensors_as_shapes.size()) {
          input_tensors_as_shapes.resize(i + 1);
        }
        ShapeHandle s;
        TF_RETURN_IF_ERROR(ConstantPartialShape(c, node, i, &s));
        input_tensors_as_shapes[i] = s;
        rerun_shape_fn = true;
      }
    }

    if (rerun_shape_fn) {
      // We have more information about the shapes on this pass,
      // so re-run shape inference.
      c->set_input_tensors(input_tensors);
      c->set_input_tensors_as_shapes(input_tensors_as_shapes);
      TF_RETURN_IF_ERROR(run_inference_lambda());
    }
  } while (rerun_shape_fn);

  return Status::OK();
}

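// Returns true only when s0 and s1 are provably the same: the handles match,
// or both ranks are known and each dimension pair either shares a handle or
// has equal, known values.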
bool ShapeRefiner::SameDefinedShape(InferenceContext* c, ShapeHandle s0,
                                    ShapeHandle s1) {
  if (s0.SameHandle(s1)) {
    return true;
  }
  if (c->Rank(s0) != c->Rank(s1)) {
    return false;
  }
  if (!c->RankKnown(s0) && !c->RankKnown(s1)) {
    return false;
  }
  for (int i = 0; i < c->Rank(s0); ++i) {
    if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) {
      int64 val0 = c->Value(c->Dim(s0, i));
      int64 val1 = c->Value(c->Dim(s1, i));
      if (val0 < 0 || val1 < 0 || val0 != val1) {
        return false;
      }
    }
  }

  return true;
}

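// Returns true if any shape or dtype in 'updated' differs observably from the
// corresponding entry in 'existing' (or if the two lists differ in size).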
bool ShapeRefiner::IsUpdatedShapesOrTypes(
    InferenceContext* c, const std::vector<ShapeAndType>& existing,
    const std::vector<ShapeAndType>& updated) {
  if (existing.size() != updated.size()) {
    return true;
  }
  for (int i = 0; i < existing.size(); i++) {
    if (!SameDefinedShape(c, existing[i].shape, updated[i].shape) ||
        existing[i].dtype != updated[i].dtype) {
      return true;
    }
  }
  return false;
}

}  // namespace tensorflow