1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <algorithm>
#include <atomic>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>
21
22 #include "tensorflow/core/common_runtime/constant_folding.h"
23
24 #include "tensorflow/core/common_runtime/device_factory.h"
25 #include "tensorflow/core/common_runtime/executor.h"
26 #include "tensorflow/core/common_runtime/function.h"
27 #include "tensorflow/core/common_runtime/graph_runner.h"
28 #include "tensorflow/core/common_runtime/memory_types.h"
29 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
30 #include "tensorflow/core/framework/log_memory.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/framework/types.h"
33 #include "tensorflow/core/graph/algorithm.h"
34 #include "tensorflow/core/graph/node_builder.h"
35 #include "tensorflow/core/graph/subgraph.h"
36 #include "tensorflow/core/lib/core/threadpool.h"
37 #include "tensorflow/core/lib/gtl/cleanup.h"
38 #include "tensorflow/core/lib/gtl/flatset.h"
39 #include "tensorflow/core/lib/strings/strcat.h"
40 #include "tensorflow/core/public/session_options.h"
41
42 namespace tensorflow {
43
44 namespace {
45
46 // Test to see if the Op is one that turns into a constant when its
47 // inputs' shapes are known.
IsShapeOp(const Node * n)48 bool IsShapeOp(const Node* n) {
49 const auto& ts = n->type_string();
50 return ts == "Shape" || ts == "ShapeN" || ts == "Rank" || ts == "Size";
51 }
52
53 // Reads the partially-known shape of each of n's inputs from shape_map, and
54 // stores it to input_shapes. Returns false if any input does not have a shape
55 // in shape_map.
ReadPartialShapesFromShapeMap(const Node * n,const std::unordered_map<string,std::vector<PartialTensorShape>> * shape_map,std::vector<PartialTensorShape> * input_shapes)56 bool ReadPartialShapesFromShapeMap(
57 const Node* n,
58 const std::unordered_map<string, std::vector<PartialTensorShape>>*
59 shape_map,
60 std::vector<PartialTensorShape>* input_shapes) {
61 CHECK(shape_map != nullptr);
62 for (const Edge* in : n->in_edges()) {
63 // Don't need to check if incoming control edges have known shapes.
64 if (in->IsControlEdge()) continue;
65 const auto known_shape_iter = shape_map->find(in->src()->name());
66 if (known_shape_iter == shape_map->end()) {
67 // One of n's inputs doesn't have known shapes, so don't replace n.
68 return false;
69 }
70 const auto& known_shape = known_shape_iter->second;
71 CHECK_GT(known_shape.size(), in->src_output()) << known_shape_iter->first;
72 input_shapes->push_back(known_shape[in->src_output()]);
73 }
74 return true;
75 }
76
77 // If all of n's inputs have fully-defined shapes, inserts those shapes as a
78 // vector of Tensors in the shape_replacement_map.
MaybeReplaceShapeOrShapeNOp(const Node * n,const std::vector<PartialTensorShape> & input_shapes,std::unordered_map<const Node *,std::vector<Tensor>> * shape_replacement_map)79 bool MaybeReplaceShapeOrShapeNOp(
80 const Node* n, const std::vector<PartialTensorShape>& input_shapes,
81 std::unordered_map<const Node*, std::vector<Tensor>>*
82 shape_replacement_map) {
83 std::vector<Tensor> defined_shape;
84 for (const auto& shape : input_shapes) {
85 if (!shape.IsFullyDefined()) {
86 return false;
87 }
88 const int rank = shape.dims();
89 DataType op_type = n->output_type(0);
90 Tensor t(op_type, TensorShape({rank}));
91 if (op_type == DT_INT64) {
92 auto vec = t.vec<int64>();
93 for (int i = 0; i < rank; ++i) {
94 vec(i) = shape.dim_size(i);
95 }
96 } else {
97 CHECK(op_type == DT_INT32);
98 auto vec = t.vec<int32>();
99 for (int i = 0; i < rank; ++i) {
100 if (shape.dim_size(i) > INT_MAX) {
101 VLOG(1) << "Node " << n->name() << " has input shape dimension " << i
102 << " of " << shape.dim_size(i) << " but type INT32 "
103 << " so not replacing as constant: this will trigger a "
104 "runtime error later.";
105 return false;
106 }
107 vec(i) = static_cast<int32>(shape.dim_size(i));
108 }
109 }
110 defined_shape.push_back(t);
111 }
112 // All the inputs had known shapes so we can replace the node by constants
113 // later in the rewrite.
114 shape_replacement_map->insert({n, defined_shape});
115 return true;
116 }
117
118 // If n's input has defined rank, inserts that rank as a Tensor in the
119 // shape_replacement_map.
MaybeReplaceRankOp(const Node * n,const std::vector<PartialTensorShape> & input_shapes,std::unordered_map<const Node *,std::vector<Tensor>> * shape_replacement_map)120 bool MaybeReplaceRankOp(const Node* n,
121 const std::vector<PartialTensorShape>& input_shapes,
122 std::unordered_map<const Node*, std::vector<Tensor>>*
123 shape_replacement_map) {
124 CHECK_EQ(input_shapes.size(), 1);
125 if (input_shapes[0].unknown_rank()) {
126 return false;
127 }
128 Tensor t(DT_INT32, TensorShape({}));
129 t.scalar<int32>()() = input_shapes[0].dims();
130 shape_replacement_map->insert({n, {t}});
131 return true;
132 }
133
134 // If n's input has defined size, inserts that size as a Tensor in the
135 // shape_replacement_map.
MaybeReplaceSizeOp(const Node * n,const std::vector<PartialTensorShape> & input_shapes,std::unordered_map<const Node *,std::vector<Tensor>> * shape_replacement_map)136 bool MaybeReplaceSizeOp(const Node* n,
137 const std::vector<PartialTensorShape>& input_shapes,
138 std::unordered_map<const Node*, std::vector<Tensor>>*
139 shape_replacement_map) {
140 CHECK_EQ(input_shapes.size(), 1);
141 if (!input_shapes[0].IsFullyDefined()) {
142 return false;
143 }
144 DataType op_type = n->output_type(0);
145 Tensor t(op_type, TensorShape({}));
146 int64 size = input_shapes[0].num_elements();
147 if (op_type == DT_INT64) {
148 t.scalar<int64>()() = size;
149 } else {
150 CHECK(op_type == DT_INT32);
151 if (size > INT_MAX) {
152 VLOG(1) << "Node " << n->name() << " has input shape size " << size
153 << " but type INT32 "
154 << " so not replacing as constant: this will trigger a runtime "
155 "error later.";
156 return false;
157 }
158 t.scalar<int32>()() = static_cast<int32>(size);
159 }
160 shape_replacement_map->insert({n, {t}});
161 return true;
162 }
163
164 // If n is a shape Op (Shape, ShapeN, Rank, or Size) and its inputs have their
165 // shapes specified in shape_map, then adds to shape_replacement_map a mapping
166 // from n to a vector of Tensors, where Tensor k is the (statically known) value
167 // on n's kth output edge. shape_replacement_map has an entry for n iff
168 // MaybeReplaceShapeOp returns true, so it's valid to use
169 // shape_replacement_map->count(n) as a test to see if n is a shape op that can
170 // be replaced.
MaybeReplaceShapeOp(const Node * n,const std::unordered_map<string,std::vector<PartialTensorShape>> * shape_map,std::unordered_map<const Node *,std::vector<Tensor>> * shape_replacement_map)171 bool MaybeReplaceShapeOp(
172 const Node* n,
173 const std::unordered_map<string, std::vector<PartialTensorShape>>*
174 shape_map,
175 std::unordered_map<const Node*, std::vector<Tensor>>*
176 shape_replacement_map) {
177 if (shape_map == nullptr || !IsShapeOp(n)) {
178 return false;
179 }
180 // input_shapes will contain the shapes of each of n's inputs.
181 std::vector<PartialTensorShape> input_shapes;
182 if (!ReadPartialShapesFromShapeMap(n, shape_map, &input_shapes)) {
183 return false;
184 }
185 const auto& ts = n->type_string();
186 if (ts == "Shape" || ts == "ShapeN") {
187 if (!MaybeReplaceShapeOrShapeNOp(n, input_shapes, shape_replacement_map)) {
188 return false;
189 }
190 } else if (ts == "Rank") {
191 if (!MaybeReplaceRankOp(n, input_shapes, shape_replacement_map)) {
192 return false;
193 }
194 } else {
195 CHECK_EQ(ts, "Size");
196 if (!MaybeReplaceSizeOp(n, input_shapes, shape_replacement_map)) {
197 return false;
198 }
199 }
200 return true;
201 }
202
203 // Returns true if n can be evaluated as constant. shape_map maps from
204 // nodes to the partially-known shapes of their outputs. consider if
205 // non-null returns a bool indicating whether a given (non-Const,
206 // non-Shape) node is eligible to be
207 // constant-propagated. shape_replacement_map is filled in with a
208 // vector of constant output tensors for constant-foldable shape nodes
209 // (Shape, ShapeN, Size, or Rank).
IsConstantFoldable(const Node * n,const std::unordered_map<string,std::vector<PartialTensorShape>> * shape_map,const std::function<bool (const Node *)> & consider,std::unordered_map<const Node *,std::vector<Tensor>> * shape_replacement_map)210 bool IsConstantFoldable(
211 const Node* n,
212 const std::unordered_map<string, std::vector<PartialTensorShape>>*
213 shape_map,
214 const std::function<bool(const Node*)>& consider,
215 std::unordered_map<const Node*, std::vector<Tensor>>*
216 shape_replacement_map) {
217 if (n->IsConstant()) {
218 return true;
219 }
220 if (MaybeReplaceShapeOp(n, shape_map, shape_replacement_map)) {
221 return true;
222 }
223 if (n->op_def().is_stateful()) {
224 return false;
225 }
226 if (consider && !consider(n)) {
227 return false;
228 }
229 if (n->IsControlFlow() || n->IsSend() || n->IsRecv()) {
230 return false;
231 }
232 // TODO(yuanbyu): For now disable these session handle operations.
233 if (n->IsGetSessionHandle() || n->IsGetSessionTensor() ||
234 n->IsDeleteSessionTensor()) {
235 return false;
236 }
237 if (n->IsSource()) {
238 return false;
239 }
240 if (n->IsSink()) {
241 return false;
242 }
243 // Since constant-folding runs on the CPU, do not attempt to constant-fold
244 // operators that have no CPU kernel. Also implies that we will not
245 // constant-fold functions.
246 // TODO(phawkins): allow constant-folding for functions; functions may
247 // be arbitrarily expensive to execute.
248 if (!FindKernelDef(DeviceType(DEVICE_CPU), n->def(), /*def=*/nullptr,
249 /*kernel_class_name=*/nullptr)
250 .ok()) {
251 return false;
252 }
253
254 return true;
255 }
256
257 // If n is eligible for constant-folding, adds it to nodes, and places its
258 // control dependencies and those transitively of its constant-foldable inputs
259 // into constant_control_deps. If n is a constant-foldable shape node (Shape,
260 // ShapeN, Rank, or Size), also puts its outputs into shape_replacement_map.
ConsiderConstantFoldableNode(Node * n,const ConstantFoldingOptions & opts,std::vector<Node * > * nodes,std::unordered_map<const Node *,gtl::FlatSet<Node * >> * constant_control_deps,std::unordered_map<const Node *,std::vector<Tensor>> * shape_replacement_map,bool * internal_node_inserted)261 void ConsiderConstantFoldableNode(
262 Node* n, const ConstantFoldingOptions& opts, std::vector<Node*>* nodes,
263 std::unordered_map<const Node*, gtl::FlatSet<Node*>>* constant_control_deps,
264 std::unordered_map<const Node*, std::vector<Tensor>>* shape_replacement_map,
265 bool* internal_node_inserted) {
266 if (IsConstantFoldable(n, opts.shape_map, opts.consider,
267 shape_replacement_map)) {
268 // A node is constant provided all of its non-control incoming Tensors come
269 // from constant nodes, or it's a shape Op with statically known inputs in
270 // which case it is placed in shape_replacement_map.
271 //
272 // We allow control dependencies from non-constant nodes to constant nodes,
273 // but to preserve the graph structure we must transfer the control
274 // dependency onto any constant replacement.
275 bool all_parents_constant = true;
276 for (const Edge* in : n->in_edges()) {
277 // Allows non-constant -> constant control edges.
278 if (!in->IsControlEdge() &&
279 constant_control_deps->count(in->src()) == 0) {
280 all_parents_constant = false;
281 break;
282 }
283 }
284 if (all_parents_constant || shape_replacement_map->count(n) != 0) {
285 gtl::FlatSet<Node*>& control_deps = (*constant_control_deps)[n];
286 for (const Edge* e : n->in_edges()) {
287 if (constant_control_deps->count(e->src()) == 0) {
288 // This branch is taken if the incoming edge is a control dependency,
289 // in which case we want to add it to the dependencies being
290 // accumulated for this node, or the incoming edge is not
291 // constant. The latter may happen when n is a shape node and the
292 // source has known shape. In that case add a control dependency from
293 // the source node, since there was previously a data dependency and
294 // we want to preserve sequencing constraints.
295 if (!e->src()->IsSource()) {
296 control_deps.insert(e->src());
297 }
298 } else {
299 // If the parent has been accumulating control dependencies, add all
300 // of its transitive control deps.
301 const gtl::FlatSet<Node*>& parent_deps =
302 (*constant_control_deps)[e->src()];
303 control_deps.insert(parent_deps.begin(), parent_deps.end());
304 }
305 }
306 nodes->push_back(n);
307 if (!n->IsConstant()) {
308 *internal_node_inserted = true;
309 }
310 }
311 }
312 }
313
314 // Returns the constant foldable nodes in `nodes` in topological order.
315 // Populates `constant_control_deps` with the non-constant control dependencies
316 // of each constant node.
FindConstantFoldableNodes(const Graph * graph,const ConstantFoldingOptions & opts,std::vector<Node * > * nodes,std::unordered_map<const Node *,gtl::FlatSet<Node * >> * constant_control_deps,std::unordered_map<const Node *,std::vector<Tensor>> * shape_replacement_map)317 void FindConstantFoldableNodes(
318 const Graph* graph, const ConstantFoldingOptions& opts,
319 std::vector<Node*>* nodes,
320 std::unordered_map<const Node*, gtl::FlatSet<Node*>>* constant_control_deps,
321 std::unordered_map<const Node*, std::vector<Tensor>>*
322 shape_replacement_map) {
323 bool internal_node_inserted = false;
324 // Walk the nodes in data flow order.
325 ReverseDFS(*graph, nullptr,
326 [nodes, constant_control_deps, shape_replacement_map,
327 &internal_node_inserted, &opts](Node* n) {
328 ConsiderConstantFoldableNode(
329 n, opts, nodes, constant_control_deps, shape_replacement_map,
330 &internal_node_inserted);
331 },
332 NodeComparatorName());
333 // If we have inserted just leaf level nodes, then there is nothing to fold.
334 if (!internal_node_inserted) {
335 nodes->clear();
336 constant_control_deps->clear();
337 }
338 }
339
340 typedef std::pair<Node*, int> NodeAndOutput;
341
UniqueConstantId()342 int64 UniqueConstantId() {
343 static std::atomic_int_fast64_t unique_constant_id;
344 return unique_constant_id.fetch_add(1);
345 }
346
347 // Adds n to constant_graph which is being built up for subsequent evaluation of
348 // constant propagation. node_map is the mapping of nodes in the original graph
349 // to nodes in the constant graph. The value of an entry in node_map is a vector
350 // of nodes because a ShapeN node in the original graph is replaced by a vector
351 // of Constant nodes in the constant graph.
AddNodeToConstantGraph(Node * n,std::unordered_map<Node *,std::vector<Node * >> * node_map,Graph * constant_graph)352 void AddNodeToConstantGraph(
353 Node* n, std::unordered_map<Node*, std::vector<Node*>>* node_map,
354 Graph* constant_graph) {
355 std::vector<Node*>& added = (*node_map)[n];
356 added.push_back(constant_graph->CopyNode(n));
357 for (const Edge* in_edge : n->in_edges()) {
358 // Don't copy control edges to the constant graph.
359 if (!in_edge->IsControlEdge()) {
360 Node* in = in_edge->src();
361 auto it = node_map->find(in);
362 CHECK(it != node_map->end())
363 << n->DebugString() << " <-" << in->DebugString();
364 if (it->second.size() == 1) {
365 constant_graph->AddEdge(it->second[0], in_edge->src_output(), added[0],
366 in_edge->dst_input());
367 } else {
368 // The original source node had multiple outputs and was replaced by a
369 // vector of constants, so the edge comes from the 0th output of the kth
370 // added constant, rather than the kth output of the added node as in
371 // the standard case above.
372 constant_graph->AddEdge(it->second[in_edge->src_output()], 0, added[0],
373 in_edge->dst_input());
374 }
375 }
376 }
377 }
378
379 // Replaces constant-foldable shape node n by a vector of constants in
380 // constant_graph, which is being built up for subsequent evaluation of constant
381 // propagation. node_map is the mapping of nodes in the original graph to nodes
382 // in the constant graph. The value of an entry in node_map is a vector of nodes
383 // because a ShapeN node in the original graph is replaced by a vector of
384 // Constant nodes in the constant graph.
AddShapeNodeToConstantGraph(Node * n,const std::unordered_map<const Node *,std::vector<Tensor>> & shape_replacement_map,std::unordered_map<Node *,std::vector<Node * >> * node_map,const ConstantFoldNameGenerator & generate_new_name,Graph * constant_graph)385 void AddShapeNodeToConstantGraph(
386 Node* n,
387 const std::unordered_map<const Node*, std::vector<Tensor>>&
388 shape_replacement_map,
389 std::unordered_map<Node*, std::vector<Node*>>* node_map,
390 const ConstantFoldNameGenerator& generate_new_name, Graph* constant_graph) {
391 std::vector<Node*>& added = (*node_map)[n];
392 const string& node_name = n->name();
393 for (const Tensor& t : shape_replacement_map.at(n)) {
394 auto builder =
395 NodeDefBuilder(generate_new_name(constant_graph, node_name), "Const")
396 .Attr("dtype", t.dtype())
397 .Attr("value", t);
398 NodeDef def;
399 CHECK(builder.Finalize(&def).ok());
400 Node* constant_node;
401 CHECK(NodeBuilder(builder).Finalize(constant_graph, &constant_node).ok());
402 added.push_back(constant_node);
403 }
404 // Don't copy incoming edges to shape nodes that are being replaced.
405 }
406
407 // Given the constant foldable nodes in 'nodes', returns a new graph 'g'. 'g'
408 // will contain copies of the nodes in 'nodes'. In addition, if there is an edge
409 // going from a node 'n' in 'nodes' to another node in 'orig_graph' but not in
410 // 'nodes', then 'tensors_to_fetch' will contain the mapping from the
411 // corresponding copy of 'n' and the edge number in 'g' to 'n'.
GetConstantGraph(const Graph * orig_graph,const std::vector<Node * > & nodes,const std::unordered_map<const Node *,std::vector<Tensor>> & shape_replacement_map,std::map<NodeAndOutput,Node * > * tensors_to_fetch,const ConstantFoldNameGenerator & generate_new_name)412 Graph* GetConstantGraph(
413 const Graph* orig_graph, const std::vector<Node*>& nodes,
414 const std::unordered_map<const Node*, std::vector<Tensor>>&
415 shape_replacement_map,
416 std::map<NodeAndOutput, Node*>* tensors_to_fetch,
417 const ConstantFoldNameGenerator& generate_new_name) {
418 Graph* constant_graph = new Graph(orig_graph->op_registry());
419 std::unordered_map<Node*, std::vector<Node*>> node_map;
420 node_map[orig_graph->source_node()] = {constant_graph->source_node()};
421 node_map[orig_graph->sink_node()] = {constant_graph->sink_node()};
422 for (Node* n : nodes) {
423 if (shape_replacement_map.count(n) == 0) {
424 AddNodeToConstantGraph(n, &node_map, constant_graph);
425 } else {
426 AddShapeNodeToConstantGraph(n, shape_replacement_map, &node_map,
427 generate_new_name, constant_graph);
428 }
429 }
430
431 for (auto const& added_nodes : node_map) {
432 for (const Edge* out_edge : added_nodes.first->out_edges()) {
433 if (node_map.count(out_edge->dst()) == 0) {
434 if (out_edge->IsControlEdge()) continue;
435 if (added_nodes.second.size() == 1) {
436 tensors_to_fetch->insert(
437 {{added_nodes.second[0], out_edge->src_output()},
438 added_nodes.first});
439 } else {
440 // The node had multiple outputs and was replaced by a
441 // vector of constants, so the NodeAndOutput is the 0th
442 // output of the kth added constant, rather than the kth
443 // output of the added node as in the standard case above.
444 tensors_to_fetch->insert(
445 {{added_nodes.second[out_edge->src_output()], 0},
446 added_nodes.first});
447 }
448 }
449 }
450 }
451
452 return constant_graph;
453 }
454
455 // Replaces the identified Tensor in 'graph' by a 'Const' node with
456 // the value supplied in 'constant'. 'partition_device', if non-null
457 // is the device where the graph executes. Returns true if the
458 // replacement was successful, false otherwise.
459 // 'control_deps' is the set of nodes that should be control predecessors of the
460 // new constant node.
ReplaceTensorWithConstant(Graph * graph,Device * partition_device,NodeAndOutput tensor,const Tensor & constant,const gtl::FlatSet<Node * > & control_deps,int64 max_constant_size_in_bytes,const ConstantFoldNameGenerator & generate_new_name)461 bool ReplaceTensorWithConstant(
462 Graph* graph, Device* partition_device, NodeAndOutput tensor,
463 const Tensor& constant, const gtl::FlatSet<Node*>& control_deps,
464 int64 max_constant_size_in_bytes,
465 const ConstantFoldNameGenerator& generate_new_name) {
466 // Be conservative when replacing a tensor with a constant, when not
467 // running on CPU.
468 // 1) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
469 // constraint, do not replace it.
470 // 2) If the destination tensor is an int32 tensor, but has DEVICE_MEMORY
471 // constraint, do not replace it.
472 // 3) If the constant op created does not have a kernel implementation
473 // for the device, do not use it.
474 // 4) If the size of the constant in bytes is too large (>
475 // max_constant_in_bytes), do not replace it. This prevents the size of the
476 // Graph from growing too large.
477 // TODO(keveman): Consider adding a new constant op that has a kernel
478 // implementation for all types, but with HostMemory constraint on it's
479 // output.
480 // 5) Do not replace another constant.
481 if (tensor.first->IsConstant()) {
482 return false;
483 }
484 DeviceType device_type = partition_device
485 ? DeviceType{partition_device->device_type()}
486 : DEVICE_CPU;
487 if (partition_device && device_type != DEVICE_CPU) {
488 MemoryType memory_type;
489 if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
490 &memory_type)
491 .ok()) {
492 return false;
493 }
494 bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32;
495 if ((memory_type == HOST_MEMORY && !is_int32) ||
496 (memory_type == DEVICE_MEMORY && is_int32)) {
497 return false;
498 }
499 }
500 if (constant.TotalBytes() > max_constant_size_in_bytes) {
501 return false;
502 }
503
504 Node* n = tensor.first;
505 std::vector<const Edge*> edges_to_remove;
506 for (const Edge* out_edge : n->out_edges()) {
507 if (out_edge->src_output() == tensor.second) {
508 edges_to_remove.push_back(out_edge);
509 }
510 }
511 const string& node_name = n->name();
512 Node* constant_node;
513 auto builder = NodeDefBuilder(generate_new_name(graph, node_name), "Const")
514 .Attr("dtype", constant.dtype())
515 .Attr("value", constant);
516 if (partition_device) {
517 builder.Device(partition_device->name());
518 }
519 NodeDef def;
520 if (!builder.Finalize(&def).ok()) {
521 return false;
522 }
523 const KernelDef* kdef;
524 if (!FindKernelDef(device_type, def, &kdef, nullptr).ok()) {
525 return false;
526 }
527
528 VLOG(1) << "Replacing " << tensor.first->name() << " :: " << tensor.second
529 << " with a constant";
530
531 if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
532 return false;
533 }
534 for (auto edge : edges_to_remove) {
535 graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
536 graph->RemoveEdge(edge);
537 }
538 if (control_deps.empty()) {
539 graph->AddControlEdge(graph->source_node(), constant_node);
540 } else {
541 for (Node* node : control_deps) {
542 graph->AddControlEdge(node, constant_node);
543 }
544 }
545 if (partition_device) {
546 constant_node->set_assigned_device_name(partition_device->name());
547 }
548 return true;
549 }
550
551 } // namespace
552
ConstantFold(const ConstantFoldingOptions & opts,FunctionLibraryRuntime * function_library,Env * env,Device * partition_device,Graph * graph,bool * was_mutated)553 Status ConstantFold(const ConstantFoldingOptions& opts,
554 FunctionLibraryRuntime* function_library, Env* env,
555 Device* partition_device, Graph* graph, bool* was_mutated) {
556 DumpGraph("Before", graph);
557 ConstantFoldNameGenerator generate_new_name = opts.generate_new_name;
558 if (generate_new_name == nullptr) {
559 generate_new_name = [](Graph* graph, string old_name) {
560 return strings::StrCat(graph->NewName(old_name), "__cf__",
561 UniqueConstantId());
562 };
563 }
564
565 std::vector<Node*> constant_foldable_nodes;
566 std::unordered_map<const Node*, gtl::FlatSet<Node*>> constant_control_deps;
567 std::unordered_map<const Node*, std::vector<Tensor>> shape_replacement_map;
568 FindConstantFoldableNodes(graph, opts, &constant_foldable_nodes,
569 &constant_control_deps, &shape_replacement_map);
570 if (constant_foldable_nodes.empty()) {
571 VLOG(1) << "No constant foldable nodes found";
572 *was_mutated = false;
573 // This is not an error, so return the status as OK.
574 return Status::OK();
575 }
576
577 std::map<NodeAndOutput, Node*> tensors_to_fetch;
578 std::unique_ptr<Graph> constant_graph(
579 GetConstantGraph(graph, constant_foldable_nodes, shape_replacement_map,
580 &tensors_to_fetch, generate_new_name));
581 DumpGraph("Constant graph", constant_graph.get());
582
583 if (tensors_to_fetch.empty()) {
584 VLOG(1) << "No constant nodes found that feed into the original graph.";
585 *was_mutated = false;
586 // This is not an error, so return the status as OK.
587 return Status::OK();
588 }
589 VLOG(1) << "Constant foldable " << constant_graph->num_node_ids() << " : "
590 << graph->num_node_ids();
591
592 std::vector<string> tensors_to_fetch_names;
593 std::vector<NodeAndOutput> tensors_to_replace;
594 // Sorting the nodes based on the name gives us a stable ordering between runs
595 // for the same graph.
596 std::vector<std::pair<NodeAndOutput, Node*>> tensors_to_fetch_sorted(
597 tensors_to_fetch.begin(), tensors_to_fetch.end());
598 std::sort(tensors_to_fetch_sorted.begin(), tensors_to_fetch_sorted.end(),
599 [](const std::pair<NodeAndOutput, Node*>& n1,
600 const std::pair<NodeAndOutput, Node*>& n2) {
601 return n1.first.first->name() < n2.first.first->name();
602 });
603 for (auto n : tensors_to_fetch_sorted) {
604 tensors_to_fetch_names.push_back(
605 strings::StrCat(n.first.first->name(), ":", n.first.second));
606 tensors_to_replace.push_back({n.second, n.first.second});
607 }
608
609 auto graph_runner = std::unique_ptr<GraphRunner>(new GraphRunner(env));
610 // Evaluate the constant foldable nodes.
611 std::vector<Tensor> outputs;
612 auto delete_tensors = gtl::MakeCleanup([&graph_runner, &outputs] {
613 // Output tensors need to be cleared before the GraphRunner is deleted.
614 outputs.clear();
615 graph_runner.reset(nullptr);
616 });
617
618 Status s =
619 graph_runner->Run(constant_graph.get(), function_library, {} /* inputs*/,
620 tensors_to_fetch_names, &outputs);
621 if (!s.ok()) {
622 VLOG(1) << "Could not fetch constants: " << s;
623 *was_mutated = false;
624 return s;
625 }
626
627 // Fetch the constant tensors and replace the corresponding tensors in the
628 // original graph with those constants.
629 int32 num_nodes_replaced = 0;
630 for (size_t c = 0; c < outputs.size(); ++c) {
631 const gtl::FlatSet<Node*>& control_deps =
632 constant_control_deps[tensors_to_replace[c].first];
633 if (ReplaceTensorWithConstant(
634 graph, partition_device, tensors_to_replace[c], outputs[c],
635 control_deps, opts.max_constant_size_in_bytes, generate_new_name)) {
636 ++num_nodes_replaced;
637 }
638 }
639
640 DumpGraph("After", graph);
641
642 *was_mutated = (num_nodes_replaced > 0);
643 return Status::OK();
644 }
645
646 } // namespace tensorflow
647