1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <atomic>
18 #include <set>
19 #include <unordered_map>
20 #include <vector>
21 
22 #include "tensorflow/core/common_runtime/constant_folding.h"
23 
24 #include "tensorflow/core/common_runtime/device_factory.h"
25 #include "tensorflow/core/common_runtime/executor.h"
26 #include "tensorflow/core/common_runtime/function.h"
27 #include "tensorflow/core/common_runtime/graph_runner.h"
28 #include "tensorflow/core/common_runtime/memory_types.h"
29 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
30 #include "tensorflow/core/framework/log_memory.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/framework/types.h"
33 #include "tensorflow/core/graph/algorithm.h"
34 #include "tensorflow/core/graph/node_builder.h"
35 #include "tensorflow/core/graph/subgraph.h"
36 #include "tensorflow/core/lib/core/threadpool.h"
37 #include "tensorflow/core/lib/gtl/cleanup.h"
38 #include "tensorflow/core/lib/gtl/flatset.h"
39 #include "tensorflow/core/lib/strings/strcat.h"
40 #include "tensorflow/core/public/session_options.h"
41 
42 namespace tensorflow {
43 
44 namespace {
45 
46 // Test to see if the Op is one that turns into a constant when its
47 // inputs' shapes are known.
IsShapeOp(const Node * n)48 bool IsShapeOp(const Node* n) {
49   const auto& ts = n->type_string();
50   return ts == "Shape" || ts == "ShapeN" || ts == "Rank" || ts == "Size";
51 }
52 
53 // Reads the partially-known shape of each of n's inputs from shape_map, and
54 // stores it to input_shapes. Returns false if any input does not have a shape
55 // in shape_map.
ReadPartialShapesFromShapeMap(const Node * n,const std::unordered_map<string,std::vector<PartialTensorShape>> * shape_map,std::vector<PartialTensorShape> * input_shapes)56 bool ReadPartialShapesFromShapeMap(
57     const Node* n,
58     const std::unordered_map<string, std::vector<PartialTensorShape>>*
59         shape_map,
60     std::vector<PartialTensorShape>* input_shapes) {
61   CHECK(shape_map != nullptr);
62   for (const Edge* in : n->in_edges()) {
63     // Don't need to check if incoming control edges have known shapes.
64     if (in->IsControlEdge()) continue;
65     const auto known_shape_iter = shape_map->find(in->src()->name());
66     if (known_shape_iter == shape_map->end()) {
67       // One of n's inputs doesn't have known shapes, so don't replace n.
68       return false;
69     }
70     const auto& known_shape = known_shape_iter->second;
71     CHECK_GT(known_shape.size(), in->src_output()) << known_shape_iter->first;
72     input_shapes->push_back(known_shape[in->src_output()]);
73   }
74   return true;
75 }
76 
77 // If all of n's inputs have fully-defined shapes, inserts those shapes as a
78 // vector of Tensors in the shape_replacement_map.
MaybeReplaceShapeOrShapeNOp(const Node * n,const std::vector<PartialTensorShape> & input_shapes,std::unordered_map<const Node *,std::vector<Tensor>> * shape_replacement_map)79 bool MaybeReplaceShapeOrShapeNOp(
80     const Node* n, const std::vector<PartialTensorShape>& input_shapes,
81     std::unordered_map<const Node*, std::vector<Tensor>>*
82         shape_replacement_map) {
83   std::vector<Tensor> defined_shape;
84   for (const auto& shape : input_shapes) {
85     if (!shape.IsFullyDefined()) {
86       return false;
87     }
88     const int rank = shape.dims();
89     DataType op_type = n->output_type(0);
90     Tensor t(op_type, TensorShape({rank}));
91     if (op_type == DT_INT64) {
92       auto vec = t.vec<int64>();
93       for (int i = 0; i < rank; ++i) {
94         vec(i) = shape.dim_size(i);
95       }
96     } else {
97       CHECK(op_type == DT_INT32);
98       auto vec = t.vec<int32>();
99       for (int i = 0; i < rank; ++i) {
100         if (shape.dim_size(i) > INT_MAX) {
101           VLOG(1) << "Node " << n->name() << " has input shape dimension " << i
102                   << " of " << shape.dim_size(i) << " but type INT32 "
103                   << " so not replacing as constant: this will trigger a "
104                      "runtime error later.";
105           return false;
106         }
107         vec(i) = static_cast<int32>(shape.dim_size(i));
108       }
109     }
110     defined_shape.push_back(t);
111   }
112   // All the inputs had known shapes so we can replace the node by constants
113   // later in the rewrite.
114   shape_replacement_map->insert({n, defined_shape});
115   return true;
116 }
117 
118 // If n's input has defined rank, inserts that rank as a Tensor in the
119 //  shape_replacement_map.
MaybeReplaceRankOp(const Node * n,const std::vector<PartialTensorShape> & input_shapes,std::unordered_map<const Node *,std::vector<Tensor>> * shape_replacement_map)120 bool MaybeReplaceRankOp(const Node* n,
121                         const std::vector<PartialTensorShape>& input_shapes,
122                         std::unordered_map<const Node*, std::vector<Tensor>>*
123                             shape_replacement_map) {
124   CHECK_EQ(input_shapes.size(), 1);
125   if (input_shapes[0].unknown_rank()) {
126     return false;
127   }
128   Tensor t(DT_INT32, TensorShape({}));
129   t.scalar<int32>()() = input_shapes[0].dims();
130   shape_replacement_map->insert({n, {t}});
131   return true;
132 }
133 
134 // If n's input has defined size, inserts that size as a Tensor in the
135 //  shape_replacement_map.
MaybeReplaceSizeOp(const Node * n,const std::vector<PartialTensorShape> & input_shapes,std::unordered_map<const Node *,std::vector<Tensor>> * shape_replacement_map)136 bool MaybeReplaceSizeOp(const Node* n,
137                         const std::vector<PartialTensorShape>& input_shapes,
138                         std::unordered_map<const Node*, std::vector<Tensor>>*
139                             shape_replacement_map) {
140   CHECK_EQ(input_shapes.size(), 1);
141   if (!input_shapes[0].IsFullyDefined()) {
142     return false;
143   }
144   DataType op_type = n->output_type(0);
145   Tensor t(op_type, TensorShape({}));
146   int64 size = input_shapes[0].num_elements();
147   if (op_type == DT_INT64) {
148     t.scalar<int64>()() = size;
149   } else {
150     CHECK(op_type == DT_INT32);
151     if (size > INT_MAX) {
152       VLOG(1) << "Node " << n->name() << " has input shape size " << size
153               << " but type INT32 "
154               << " so not replacing as constant: this will trigger a runtime "
155                  "error later.";
156       return false;
157     }
158     t.scalar<int32>()() = static_cast<int32>(size);
159   }
160   shape_replacement_map->insert({n, {t}});
161   return true;
162 }
163 
164 // If n is a shape Op (Shape, ShapeN, Rank, or Size) and its inputs have their
165 // shapes specified in shape_map, then adds to shape_replacement_map a mapping
166 // from n to a vector of Tensors, where Tensor k is the (statically known) value
167 // on n's kth output edge. shape_replacement_map has an entry for n iff
168 // MaybeReplaceShapeOp returns true, so it's valid to use
169 // shape_replacement_map->count(n) as a test to see if n is a shape op that can
170 // be replaced.
MaybeReplaceShapeOp(const Node * n,const std::unordered_map<string,std::vector<PartialTensorShape>> * shape_map,std::unordered_map<const Node *,std::vector<Tensor>> * shape_replacement_map)171 bool MaybeReplaceShapeOp(
172     const Node* n,
173     const std::unordered_map<string, std::vector<PartialTensorShape>>*
174         shape_map,
175     std::unordered_map<const Node*, std::vector<Tensor>>*
176         shape_replacement_map) {
177   if (shape_map == nullptr || !IsShapeOp(n)) {
178     return false;
179   }
180   // input_shapes will contain the shapes of each of n's inputs.
181   std::vector<PartialTensorShape> input_shapes;
182   if (!ReadPartialShapesFromShapeMap(n, shape_map, &input_shapes)) {
183     return false;
184   }
185   const auto& ts = n->type_string();
186   if (ts == "Shape" || ts == "ShapeN") {
187     if (!MaybeReplaceShapeOrShapeNOp(n, input_shapes, shape_replacement_map)) {
188       return false;
189     }
190   } else if (ts == "Rank") {
191     if (!MaybeReplaceRankOp(n, input_shapes, shape_replacement_map)) {
192       return false;
193     }
194   } else {
195     CHECK_EQ(ts, "Size");
196     if (!MaybeReplaceSizeOp(n, input_shapes, shape_replacement_map)) {
197       return false;
198     }
199   }
200   return true;
201 }
202 
203 // Returns true if n can be evaluated as constant. shape_map maps from
204 // nodes to the partially-known shapes of their outputs. consider if
205 // non-null returns a bool indicating whether a given (non-Const,
206 // non-Shape) node is eligible to be
207 // constant-propagated. shape_replacement_map is filled in with a
208 // vector of constant output tensors for constant-foldable shape nodes
209 // (Shape, ShapeN, Size, or Rank).
IsConstantFoldable(const Node * n,const std::unordered_map<string,std::vector<PartialTensorShape>> * shape_map,const std::function<bool (const Node *)> & consider,std::unordered_map<const Node *,std::vector<Tensor>> * shape_replacement_map)210 bool IsConstantFoldable(
211     const Node* n,
212     const std::unordered_map<string, std::vector<PartialTensorShape>>*
213         shape_map,
214     const std::function<bool(const Node*)>& consider,
215     std::unordered_map<const Node*, std::vector<Tensor>>*
216         shape_replacement_map) {
217   if (n->IsConstant()) {
218     return true;
219   }
220   if (MaybeReplaceShapeOp(n, shape_map, shape_replacement_map)) {
221     return true;
222   }
223   if (n->op_def().is_stateful()) {
224     return false;
225   }
226   if (consider && !consider(n)) {
227     return false;
228   }
229   if (n->IsControlFlow() || n->IsSend() || n->IsRecv()) {
230     return false;
231   }
232   // TODO(yuanbyu): For now disable these session handle operations.
233   if (n->IsGetSessionHandle() || n->IsGetSessionTensor() ||
234       n->IsDeleteSessionTensor()) {
235     return false;
236   }
237   if (n->IsSource()) {
238     return false;
239   }
240   if (n->IsSink()) {
241     return false;
242   }
243   // Since constant-folding runs on the CPU, do not attempt to constant-fold
244   // operators that have no CPU kernel. Also implies that we will not
245   // constant-fold functions.
246   // TODO(phawkins): allow constant-folding for functions; functions may
247   // be arbitrarily expensive to execute.
248   if (!FindKernelDef(DeviceType(DEVICE_CPU), n->def(), /*def=*/nullptr,
249                      /*kernel_class_name=*/nullptr)
250            .ok()) {
251     return false;
252   }
253 
254   return true;
255 }
256 
257 // If n is eligible for constant-folding, adds it to nodes, and places its
258 // control dependencies and those transitively of its constant-foldable inputs
259 // into constant_control_deps. If n is a constant-foldable shape node (Shape,
260 // ShapeN, Rank, or Size), also puts its outputs into shape_replacement_map.
ConsiderConstantFoldableNode(Node * n,const ConstantFoldingOptions & opts,std::vector<Node * > * nodes,std::unordered_map<const Node *,gtl::FlatSet<Node * >> * constant_control_deps,std::unordered_map<const Node *,std::vector<Tensor>> * shape_replacement_map,bool * internal_node_inserted)261 void ConsiderConstantFoldableNode(
262     Node* n, const ConstantFoldingOptions& opts, std::vector<Node*>* nodes,
263     std::unordered_map<const Node*, gtl::FlatSet<Node*>>* constant_control_deps,
264     std::unordered_map<const Node*, std::vector<Tensor>>* shape_replacement_map,
265     bool* internal_node_inserted) {
266   if (IsConstantFoldable(n, opts.shape_map, opts.consider,
267                          shape_replacement_map)) {
268     // A node is constant provided all of its non-control incoming Tensors come
269     // from constant nodes, or it's a shape Op with statically known inputs in
270     // which case it is placed in shape_replacement_map.
271     //
272     // We allow control dependencies from non-constant nodes to constant nodes,
273     // but to preserve the graph structure we must transfer the control
274     // dependency onto any constant replacement.
275     bool all_parents_constant = true;
276     for (const Edge* in : n->in_edges()) {
277       // Allows non-constant -> constant control edges.
278       if (!in->IsControlEdge() &&
279           constant_control_deps->count(in->src()) == 0) {
280         all_parents_constant = false;
281         break;
282       }
283     }
284     if (all_parents_constant || shape_replacement_map->count(n) != 0) {
285       gtl::FlatSet<Node*>& control_deps = (*constant_control_deps)[n];
286       for (const Edge* e : n->in_edges()) {
287         if (constant_control_deps->count(e->src()) == 0) {
288           // This branch is taken if the incoming edge is a control dependency,
289           // in which case we want to add it to the dependencies being
290           // accumulated for this node, or the incoming edge is not
291           // constant. The latter may happen when n is a shape node and the
292           // source has known shape. In that case add a control dependency from
293           // the source node, since there was previously a data dependency and
294           // we want to preserve sequencing constraints.
295           if (!e->src()->IsSource()) {
296             control_deps.insert(e->src());
297           }
298         } else {
299           // If the parent has been accumulating control dependencies, add all
300           // of its transitive control deps.
301           const gtl::FlatSet<Node*>& parent_deps =
302               (*constant_control_deps)[e->src()];
303           control_deps.insert(parent_deps.begin(), parent_deps.end());
304         }
305       }
306       nodes->push_back(n);
307       if (!n->IsConstant()) {
308         *internal_node_inserted = true;
309       }
310     }
311   }
312 }
313 
314 // Returns the constant foldable nodes in `nodes` in topological order.
315 // Populates `constant_control_deps` with the non-constant control dependencies
316 // of each constant node.
FindConstantFoldableNodes(const Graph * graph,const ConstantFoldingOptions & opts,std::vector<Node * > * nodes,std::unordered_map<const Node *,gtl::FlatSet<Node * >> * constant_control_deps,std::unordered_map<const Node *,std::vector<Tensor>> * shape_replacement_map)317 void FindConstantFoldableNodes(
318     const Graph* graph, const ConstantFoldingOptions& opts,
319     std::vector<Node*>* nodes,
320     std::unordered_map<const Node*, gtl::FlatSet<Node*>>* constant_control_deps,
321     std::unordered_map<const Node*, std::vector<Tensor>>*
322         shape_replacement_map) {
323   bool internal_node_inserted = false;
324   // Walk the nodes in data flow order.
325   ReverseDFS(*graph, nullptr,
326              [nodes, constant_control_deps, shape_replacement_map,
327               &internal_node_inserted, &opts](Node* n) {
328                ConsiderConstantFoldableNode(
329                    n, opts, nodes, constant_control_deps, shape_replacement_map,
330                    &internal_node_inserted);
331              },
332              NodeComparatorName());
333   // If we have inserted just leaf level nodes, then there is nothing to fold.
334   if (!internal_node_inserted) {
335     nodes->clear();
336     constant_control_deps->clear();
337   }
338 }
339 
340 typedef std::pair<Node*, int> NodeAndOutput;
341 
UniqueConstantId()342 int64 UniqueConstantId() {
343   static std::atomic_int_fast64_t unique_constant_id;
344   return unique_constant_id.fetch_add(1);
345 }
346 
347 // Adds n to constant_graph which is being built up for subsequent evaluation of
348 // constant propagation. node_map is the mapping of nodes in the original graph
349 // to nodes in the constant graph. The value of an entry in node_map is a vector
350 // of nodes because a ShapeN node in the original graph is replaced by a vector
351 // of Constant nodes in the constant graph.
AddNodeToConstantGraph(Node * n,std::unordered_map<Node *,std::vector<Node * >> * node_map,Graph * constant_graph)352 void AddNodeToConstantGraph(
353     Node* n, std::unordered_map<Node*, std::vector<Node*>>* node_map,
354     Graph* constant_graph) {
355   std::vector<Node*>& added = (*node_map)[n];
356   added.push_back(constant_graph->CopyNode(n));
357   for (const Edge* in_edge : n->in_edges()) {
358     // Don't copy control edges to the constant graph.
359     if (!in_edge->IsControlEdge()) {
360       Node* in = in_edge->src();
361       auto it = node_map->find(in);
362       CHECK(it != node_map->end())
363           << n->DebugString() << " <-" << in->DebugString();
364       if (it->second.size() == 1) {
365         constant_graph->AddEdge(it->second[0], in_edge->src_output(), added[0],
366                                 in_edge->dst_input());
367       } else {
368         // The original source node had multiple outputs and was replaced by a
369         // vector of constants, so the edge comes from the 0th output of the kth
370         // added constant, rather than the kth output of the added node as in
371         // the standard case above.
372         constant_graph->AddEdge(it->second[in_edge->src_output()], 0, added[0],
373                                 in_edge->dst_input());
374       }
375     }
376   }
377 }
378 
379 // Replaces constant-foldable shape node n by a vector of constants in
380 // constant_graph, which is being built up for subsequent evaluation of constant
381 // propagation. node_map is the mapping of nodes in the original graph to nodes
382 // in the constant graph. The value of an entry in node_map is a vector of nodes
383 // because a ShapeN node in the original graph is replaced by a vector of
384 // Constant nodes in the constant graph.
AddShapeNodeToConstantGraph(Node * n,const std::unordered_map<const Node *,std::vector<Tensor>> & shape_replacement_map,std::unordered_map<Node *,std::vector<Node * >> * node_map,const ConstantFoldNameGenerator & generate_new_name,Graph * constant_graph)385 void AddShapeNodeToConstantGraph(
386     Node* n,
387     const std::unordered_map<const Node*, std::vector<Tensor>>&
388         shape_replacement_map,
389     std::unordered_map<Node*, std::vector<Node*>>* node_map,
390     const ConstantFoldNameGenerator& generate_new_name, Graph* constant_graph) {
391   std::vector<Node*>& added = (*node_map)[n];
392   const string& node_name = n->name();
393   for (const Tensor& t : shape_replacement_map.at(n)) {
394     auto builder =
395         NodeDefBuilder(generate_new_name(constant_graph, node_name), "Const")
396             .Attr("dtype", t.dtype())
397             .Attr("value", t);
398     NodeDef def;
399     CHECK(builder.Finalize(&def).ok());
400     Node* constant_node;
401     CHECK(NodeBuilder(builder).Finalize(constant_graph, &constant_node).ok());
402     added.push_back(constant_node);
403   }
404   // Don't copy incoming edges to shape nodes that are being replaced.
405 }
406 
407 // Given the constant foldable nodes in 'nodes', returns a new graph 'g'. 'g'
408 // will contain copies of the nodes in 'nodes'. In addition, if there is an edge
409 // going from a node 'n' in 'nodes' to another node in 'orig_graph' but not in
410 // 'nodes', then 'tensors_to_fetch' will contain the mapping from the
411 // corresponding copy of 'n' and the edge number in 'g' to 'n'.
GetConstantGraph(const Graph * orig_graph,const std::vector<Node * > & nodes,const std::unordered_map<const Node *,std::vector<Tensor>> & shape_replacement_map,std::map<NodeAndOutput,Node * > * tensors_to_fetch,const ConstantFoldNameGenerator & generate_new_name)412 Graph* GetConstantGraph(
413     const Graph* orig_graph, const std::vector<Node*>& nodes,
414     const std::unordered_map<const Node*, std::vector<Tensor>>&
415         shape_replacement_map,
416     std::map<NodeAndOutput, Node*>* tensors_to_fetch,
417     const ConstantFoldNameGenerator& generate_new_name) {
418   Graph* constant_graph = new Graph(orig_graph->op_registry());
419   std::unordered_map<Node*, std::vector<Node*>> node_map;
420   node_map[orig_graph->source_node()] = {constant_graph->source_node()};
421   node_map[orig_graph->sink_node()] = {constant_graph->sink_node()};
422   for (Node* n : nodes) {
423     if (shape_replacement_map.count(n) == 0) {
424       AddNodeToConstantGraph(n, &node_map, constant_graph);
425     } else {
426       AddShapeNodeToConstantGraph(n, shape_replacement_map, &node_map,
427                                   generate_new_name, constant_graph);
428     }
429   }
430 
431   for (auto const& added_nodes : node_map) {
432     for (const Edge* out_edge : added_nodes.first->out_edges()) {
433       if (node_map.count(out_edge->dst()) == 0) {
434         if (out_edge->IsControlEdge()) continue;
435         if (added_nodes.second.size() == 1) {
436           tensors_to_fetch->insert(
437               {{added_nodes.second[0], out_edge->src_output()},
438                added_nodes.first});
439         } else {
440           // The node had multiple outputs and was replaced by a
441           // vector of constants, so the NodeAndOutput is the 0th
442           // output of the kth added constant, rather than the kth
443           // output of the added node as in the standard case above.
444           tensors_to_fetch->insert(
445               {{added_nodes.second[out_edge->src_output()], 0},
446                added_nodes.first});
447         }
448       }
449     }
450   }
451 
452   return constant_graph;
453 }
454 
455 // Replaces the identified Tensor in 'graph' by a 'Const' node with
456 // the value supplied in 'constant'. 'partition_device', if non-null
457 // is the device where the graph executes. Returns true if the
458 // replacement was successful, false otherwise.
459 // 'control_deps' is the set of nodes that should be control predecessors of the
460 // new constant node.
ReplaceTensorWithConstant(Graph * graph,Device * partition_device,NodeAndOutput tensor,const Tensor & constant,const gtl::FlatSet<Node * > & control_deps,int64 max_constant_size_in_bytes,const ConstantFoldNameGenerator & generate_new_name)461 bool ReplaceTensorWithConstant(
462     Graph* graph, Device* partition_device, NodeAndOutput tensor,
463     const Tensor& constant, const gtl::FlatSet<Node*>& control_deps,
464     int64 max_constant_size_in_bytes,
465     const ConstantFoldNameGenerator& generate_new_name) {
466   // Be conservative when replacing a tensor with a constant, when not
467   // running on CPU.
468   // 1) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
469   // constraint, do not replace it.
470   // 2) If the destination tensor is an int32 tensor, but has DEVICE_MEMORY
471   // constraint, do not replace it.
472   // 3) If the constant op created does not have a kernel implementation
473   // for the device, do not use it.
474   // 4) If the size of the constant in bytes is too large (>
475   // max_constant_in_bytes), do not replace it. This prevents the size of the
476   // Graph from growing too large.
477   // TODO(keveman): Consider adding a new constant op that has a kernel
478   // implementation for all types, but with HostMemory constraint on it's
479   // output.
480   // 5) Do not replace another constant.
481   if (tensor.first->IsConstant()) {
482     return false;
483   }
484   DeviceType device_type = partition_device
485                                ? DeviceType{partition_device->device_type()}
486                                : DEVICE_CPU;
487   if (partition_device && device_type != DEVICE_CPU) {
488     MemoryType memory_type;
489     if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
490                              &memory_type)
491              .ok()) {
492       return false;
493     }
494     bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32;
495     if ((memory_type == HOST_MEMORY && !is_int32) ||
496         (memory_type == DEVICE_MEMORY && is_int32)) {
497       return false;
498     }
499   }
500   if (constant.TotalBytes() > max_constant_size_in_bytes) {
501     return false;
502   }
503 
504   Node* n = tensor.first;
505   std::vector<const Edge*> edges_to_remove;
506   for (const Edge* out_edge : n->out_edges()) {
507     if (out_edge->src_output() == tensor.second) {
508       edges_to_remove.push_back(out_edge);
509     }
510   }
511   const string& node_name = n->name();
512   Node* constant_node;
513   auto builder = NodeDefBuilder(generate_new_name(graph, node_name), "Const")
514                      .Attr("dtype", constant.dtype())
515                      .Attr("value", constant);
516   if (partition_device) {
517     builder.Device(partition_device->name());
518   }
519   NodeDef def;
520   if (!builder.Finalize(&def).ok()) {
521     return false;
522   }
523   const KernelDef* kdef;
524   if (!FindKernelDef(device_type, def, &kdef, nullptr).ok()) {
525     return false;
526   }
527 
528   VLOG(1) << "Replacing " << tensor.first->name() << " :: " << tensor.second
529           << " with a constant";
530 
531   if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
532     return false;
533   }
534   for (auto edge : edges_to_remove) {
535     graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
536     graph->RemoveEdge(edge);
537   }
538   if (control_deps.empty()) {
539     graph->AddControlEdge(graph->source_node(), constant_node);
540   } else {
541     for (Node* node : control_deps) {
542       graph->AddControlEdge(node, constant_node);
543     }
544   }
545   if (partition_device) {
546     constant_node->set_assigned_device_name(partition_device->name());
547   }
548   return true;
549 }
550 
551 }  // namespace
552 
ConstantFold(const ConstantFoldingOptions & opts,FunctionLibraryRuntime * function_library,Env * env,Device * partition_device,Graph * graph,bool * was_mutated)553 Status ConstantFold(const ConstantFoldingOptions& opts,
554                     FunctionLibraryRuntime* function_library, Env* env,
555                     Device* partition_device, Graph* graph, bool* was_mutated) {
556   DumpGraph("Before", graph);
557   ConstantFoldNameGenerator generate_new_name = opts.generate_new_name;
558   if (generate_new_name == nullptr) {
559     generate_new_name = [](Graph* graph, string old_name) {
560       return strings::StrCat(graph->NewName(old_name), "__cf__",
561                              UniqueConstantId());
562     };
563   }
564 
565   std::vector<Node*> constant_foldable_nodes;
566   std::unordered_map<const Node*, gtl::FlatSet<Node*>> constant_control_deps;
567   std::unordered_map<const Node*, std::vector<Tensor>> shape_replacement_map;
568   FindConstantFoldableNodes(graph, opts, &constant_foldable_nodes,
569                             &constant_control_deps, &shape_replacement_map);
570   if (constant_foldable_nodes.empty()) {
571     VLOG(1) << "No constant foldable nodes found";
572     *was_mutated = false;
573     // This is not an error, so return the status as OK.
574     return Status::OK();
575   }
576 
577   std::map<NodeAndOutput, Node*> tensors_to_fetch;
578   std::unique_ptr<Graph> constant_graph(
579       GetConstantGraph(graph, constant_foldable_nodes, shape_replacement_map,
580                        &tensors_to_fetch, generate_new_name));
581   DumpGraph("Constant graph", constant_graph.get());
582 
583   if (tensors_to_fetch.empty()) {
584     VLOG(1) << "No constant nodes found that feed into the original graph.";
585     *was_mutated = false;
586     // This is not an error, so return the status as OK.
587     return Status::OK();
588   }
589   VLOG(1) << "Constant foldable " << constant_graph->num_node_ids() << " : "
590           << graph->num_node_ids();
591 
592   std::vector<string> tensors_to_fetch_names;
593   std::vector<NodeAndOutput> tensors_to_replace;
594   // Sorting the nodes based on the name gives us a stable ordering between runs
595   // for the same graph.
596   std::vector<std::pair<NodeAndOutput, Node*>> tensors_to_fetch_sorted(
597       tensors_to_fetch.begin(), tensors_to_fetch.end());
598   std::sort(tensors_to_fetch_sorted.begin(), tensors_to_fetch_sorted.end(),
599             [](const std::pair<NodeAndOutput, Node*>& n1,
600                const std::pair<NodeAndOutput, Node*>& n2) {
601               return n1.first.first->name() < n2.first.first->name();
602             });
603   for (auto n : tensors_to_fetch_sorted) {
604     tensors_to_fetch_names.push_back(
605         strings::StrCat(n.first.first->name(), ":", n.first.second));
606     tensors_to_replace.push_back({n.second, n.first.second});
607   }
608 
609   auto graph_runner = std::unique_ptr<GraphRunner>(new GraphRunner(env));
610   // Evaluate the constant foldable nodes.
611   std::vector<Tensor> outputs;
612   auto delete_tensors = gtl::MakeCleanup([&graph_runner, &outputs] {
613     // Output tensors need to be cleared before the GraphRunner is deleted.
614     outputs.clear();
615     graph_runner.reset(nullptr);
616   });
617 
618   Status s =
619       graph_runner->Run(constant_graph.get(), function_library, {} /* inputs*/,
620                         tensors_to_fetch_names, &outputs);
621   if (!s.ok()) {
622     VLOG(1) << "Could not fetch constants: " << s;
623     *was_mutated = false;
624     return s;
625   }
626 
627   // Fetch the constant tensors and replace the corresponding tensors in the
628   // original graph with those constants.
629   int32 num_nodes_replaced = 0;
630   for (size_t c = 0; c < outputs.size(); ++c) {
631     const gtl::FlatSet<Node*>& control_deps =
632         constant_control_deps[tensors_to_replace[c].first];
633     if (ReplaceTensorWithConstant(
634             graph, partition_device, tensors_to_replace[c], outputs[c],
635             control_deps, opts.max_constant_size_in_bytes, generate_new_name)) {
636       ++num_nodes_replaced;
637     }
638   }
639 
640   DumpGraph("After", graph);
641 
642   *was_mutated = (num_nodes_replaced > 0);
643   return Status::OK();
644 }
645 
646 }  // namespace tensorflow
647