1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
17 
#include <algorithm>
#include <functional>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
23 
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/framework/op.h"
27 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
28 #include "tensorflow/core/framework/tensor_shape.pb.h"
29 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
30 #include "tensorflow/core/grappler/costs/graph_memory.h"
31 #include "tensorflow/core/grappler/costs/graph_properties.h"
32 #include "tensorflow/core/grappler/costs/utils.h"
33 #include "tensorflow/core/grappler/graph_topology_view.h"
34 #include "tensorflow/core/grappler/grappler_item.h"
35 #include "tensorflow/core/grappler/mutable_graph_view.h"
36 #include "tensorflow/core/grappler/op_types.h"
37 #include "tensorflow/core/grappler/optimizers/static_schedule.h"
38 #include "tensorflow/core/grappler/utils.h"
39 #include "tensorflow/core/grappler/utils/topological_sort.h"
40 #include "tensorflow/core/grappler/utils/traversal.h"
41 #include "tensorflow/core/lib/math/math_util.h"
42 #include "tensorflow/core/lib/strings/str_util.h"
43 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
44 #include "tensorflow/core/util/device_name_utils.h"
45 
46 namespace tensorflow {
47 namespace grappler {
48 
49 namespace {
50 
// Prefix added to nodes which are recomputed.
// (constexpr arrays rather than `const char*`: the pointer itself was
// previously reassignable.)
constexpr char kRecomputedNodePrefix[] = "Recomputed";
// Prefix added to the NoOp "trigger" nodes that gate each recomputed node.
constexpr char kRecomputeTriggerNodePrefix[] = "RecomputeTrigger";
// Attribute which may be added to nodes to manually allow them to be
// recomputed.
constexpr char kRecomputeHint[] = "_recompute_hint";
57 
// Returns the set of op types that are considered cheap enough to recompute
// (instead of keeping their outputs alive) in exchange for lower memory use.
// A fresh set is built on every call.
// TODO(allenl): Replace this list with a cost model.
std::unordered_set<std::string> GetCheapToRecomputeOps() {
  return {
      "Add",        "AddN",         "BiasAdd",
      "Cast",       "Fill",         "FloorDiv",
      "FloorMod",   "FusedBatchNorm", "Mul",
      "Neg",        "RealDiv",      "Reciprocal",
      "Relu",       "Relu6",        "Reshape",
      "Rsqrt",      "Sigmoid",      "Sqrt",
      "Square",     "SquaredDifference", "Sub",
      "Tile",       "Transpose",
  };
}
69 
70 // Find recomputable ops which feed into target nodes.
FindCandidateRecomputeNodes(const NodeMap & node_map,const GraphDef * graph,const std::function<bool (const NodeDef &)> & is_candidate,const std::function<bool (const NodeDef &)> & is_target)71 std::unordered_set<const NodeDef*> FindCandidateRecomputeNodes(
72     const NodeMap& node_map, const GraphDef* graph,
73     const std::function<bool(const NodeDef&)>& is_candidate,
74     const std::function<bool(const NodeDef&)>& is_target) {
75   std::unordered_set<const NodeDef*> candidate_recompute_nodes;
76   for (const auto& node : graph->node()) {
77     if (!is_candidate(node)) {
78       continue;
79     }
80     bool has_target_output = false;
81     for (const NodeDef* output : node_map.GetOutputs(node.name())) {
82       // It only makes sense to recompute this if it feeds into a target
83       // node. We expand this to dependencies in GetOpGroupsToRecompute.
84       if (is_target(*output)) {
85         has_target_output = true;
86         break;
87       }
88     }
89     if (!has_target_output) {
90       continue;
91     }
92     bool has_target_input = false;
93     for (const string& input_name : node.input()) {
94       // Don't recompute nodes which depend on target nodes.
95       const NodeDef* input_node = node_map.GetNode(input_name);
96       if (is_target(*input_node)) {
97         has_target_input = true;
98         break;
99       }
100     }
101     if (has_target_input) {
102       continue;
103     }
104     candidate_recompute_nodes.insert(&node);
105   }
106   return candidate_recompute_nodes;
107 }
108 
connected_subgraph(const NodeMap & node_map,bool collect_inputs,bool collect_outputs,const std::function<bool (const NodeDef &)> & is_candidate,std::unordered_set<const NodeDef * > * expanded_nodes)109 void connected_subgraph(const NodeMap& node_map, bool collect_inputs,
110                         bool collect_outputs,
111                         const std::function<bool(const NodeDef&)>& is_candidate,
112                         std::unordered_set<const NodeDef*>* expanded_nodes) {
113   std::queue<const NodeDef*> to_visit;
114   for (const NodeDef* starting_node : *expanded_nodes) {
115     to_visit.push(starting_node);
116   }
117   expanded_nodes->clear();
118   while (!to_visit.empty()) {
119     const NodeDef* current_node = to_visit.front();
120     to_visit.pop();
121     if (!expanded_nodes->insert(current_node).second) {
122       // We already visited this node
123       continue;
124     }
125     if (collect_inputs) {
126       // Add inputs and outputs to this subgraph if they are candidates
127       for (const string& input_name_raw : current_node->input()) {
128         const NodeDef* input_node = node_map.GetNode(input_name_raw);
129         if (expanded_nodes->count(input_node) == 0 &&
130             is_candidate(*input_node)) {
131           to_visit.push(input_node);
132         }
133       }
134     }
135     if (collect_outputs) {
136       for (const NodeDef* output : node_map.GetOutputs(current_node->name())) {
137         if (expanded_nodes->count(output) == 0 && is_candidate(*output)) {
138           to_visit.push(output);
139         }
140       }
141     }
142   }
143 }
144 
// One group of nodes to recompute together, as built by
// GetOpGroupsToRecompute and consumed by RecomputeSubgraph.
struct RecomputedSubGraph {
  // Original nodes to duplicate. Each one either feeds a node in
  // `target_nodes` directly or (transitively) feeds another node in this set
  // that does.
  std::unordered_set<const NodeDef*> recomputed_source_nodes;
  // Consumers whose inputs are redirected from the originals in
  // `recomputed_source_nodes` to their recomputed copies.
  std::unordered_set<NodeDef*> target_nodes;
};
149 
150 // Find groups of ops to recompute together based on `should_recompute`.
GetOpGroupsToRecompute(const GraphDef * graph,const NodeMap & node_map,const std::function<bool (const NodeDef &)> & should_recompute,const std::function<bool (const NodeDef &)> & is_target)151 std::vector<RecomputedSubGraph> GetOpGroupsToRecompute(
152     const GraphDef* graph, const NodeMap& node_map,
153     const std::function<bool(const NodeDef&)>& should_recompute,
154     const std::function<bool(const NodeDef&)>& is_target) {
155   std::unordered_set<const NodeDef*> visited_nodes;
156   std::vector<RecomputedSubGraph> subgraphs_to_recompute;
157   std::unordered_set<const NodeDef*> candidate_recompute_nodes =
158       FindCandidateRecomputeNodes(node_map, graph, should_recompute, is_target);
159   for (const NodeDef* recompute_node : candidate_recompute_nodes) {
160     if (visited_nodes.count(recompute_node) > 0) {
161       continue;
162     }
163     RecomputedSubGraph current_recomputation;
164     // Build out recomputation groups by expanding to inexpensive-to-recompute
165     // nodes which do not feed target nodes. The goal is to capture some
166     // intermediate activations within this graph.
167     std::unordered_set<const NodeDef*> unpruned_recompute_nodes;
168     unpruned_recompute_nodes.insert(recompute_node);
169     connected_subgraph(node_map,
170                        true,  // Collect inputs
171                        true,  // Collect outputs
172                        should_recompute, &unpruned_recompute_nodes);
173     visited_nodes.insert(unpruned_recompute_nodes.begin(),
174                          unpruned_recompute_nodes.end());
175     for (const NodeDef* recompute_node : unpruned_recompute_nodes) {
176       bool inserted_feed = false;
177       for (NodeDef* output : node_map.GetOutputs(recompute_node->name())) {
178         if (is_target(*output)) {
179           current_recomputation.target_nodes.insert(output);
180           if (!inserted_feed) {
181             // Keep track of nodes which feed directly into a target node. These
182             // and nodes which feed into them will define the recomputed
183             // subgraph.
184             current_recomputation.recomputed_source_nodes.insert(
185                 recompute_node);
186             inserted_feed = true;
187           }
188         }
189       }
190     }
191     // Recompute only nodes which eventually feed into a target node.
192     connected_subgraph(
193         node_map,
194         true,   // Collect inputs
195         false,  // Collect outputs
196         [&unpruned_recompute_nodes](const NodeDef& node) {
197           return unpruned_recompute_nodes.count(&node) != 0;
198         },
199         &current_recomputation.recomputed_source_nodes);
200     if (current_recomputation.target_nodes.empty()) {
201       continue;
202     }
203     subgraphs_to_recompute.push_back(current_recomputation);
204   }
205   return subgraphs_to_recompute;
206 }
207 
208 // Computes the maximum topological numbers of (1) target node components
209 // (gradient nodes being fed by the recomputation), and (2) child recompute node
210 // components for each recomputed node. We will not attach any control
211 // dependencies to a recomputation unless they have component numbers greater
212 // than this value (to prevent cycles).
GetMaxDownstreamComponents(const std::unordered_set<const NodeDef * > & recomputed_source_nodes,const std::unordered_set<NodeDef * > & target_nodes,const NodeMap & node_map,const std::unordered_map<const NodeDef *,int> & components)213 std::unordered_map<const NodeDef*, int> GetMaxDownstreamComponents(
214     const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
215     const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
216     const std::unordered_map<const NodeDef*, int>& components) {
217   std::unordered_map<const NodeDef*, int> recomputed_node_components;
218   // Start by setting component numbers to the maximum among target nodes.
219   for (const NodeDef* original_recompute_node : recomputed_source_nodes) {
220     int max_target_component = -1;
221     for (NodeDef* output :
222          node_map.GetOutputs(original_recompute_node->name())) {
223       if (target_nodes.count(output) != 0) {
224         int current_target_component = components.find(output)->second;
225         if (current_target_component > max_target_component) {
226           max_target_component = current_target_component;
227         }
228       }
229     }
230     if (max_target_component > -1) {
231       recomputed_node_components[original_recompute_node] =
232           max_target_component;
233     }
234   }
235   // Sort recomputed nodes topologically (based on the original graph) so we can
236   // efficiently assign to each node the maximum of its recomputed child
237   // components and its own targets.
238   std::vector<const NodeDef*> recomputed_source_nodes_topological(
239       recomputed_source_nodes.begin(), recomputed_source_nodes.end());
240   std::sort(recomputed_source_nodes_topological.begin(),
241             recomputed_source_nodes_topological.end(),
242             [&components](const NodeDef* first, const NodeDef* second) {
243               return components.find(first)->second <
244                      components.find(second)->second;
245             });
246   for (const NodeDef* original_recompute_node :
247        recomputed_source_nodes_topological) {
248     int max_component;
249     auto recomputed_component_iterator =
250         recomputed_node_components.find(original_recompute_node);
251     if (recomputed_component_iterator != recomputed_node_components.end()) {
252       max_component = recomputed_component_iterator->second;
253     } else {
254       max_component = -1;
255     }
256     for (NodeDef* output :
257          node_map.GetOutputs(original_recompute_node->name())) {
258       if (recomputed_source_nodes.count(output) == 0) {
259         continue;
260       }
261       auto child_component_iterator = recomputed_node_components.find(output);
262       CHECK(child_component_iterator != recomputed_node_components.end());
263       int child_component = child_component_iterator->second;
264       if (child_component > max_component) {
265         max_component = child_component;
266       }
267     }
268     CHECK_GE(max_component, 0);
269     recomputed_node_components[original_recompute_node] = max_component;
270   }
271   return recomputed_node_components;
272 }
273 
274 // Modifies `graph`, adding trigger nodes and returning a mapping from
275 // `recomputed_source_nodes` to trigger nodes which will not create loops in the
276 // graph (using the component numberings in `components` and
277 // `recomputed_node_max_feed_components`). The copied nodes (not the nodes in
278 // recomputed_source_nodes, which are the originals) eventually get these
279 // control dependencies.
280 std::unordered_map<const NodeDef*, const NodeDef*>
AddRecomputeControlDependencyNodes(const std::unordered_set<const NodeDef * > & recomputed_source_nodes,const std::unordered_set<NodeDef * > & target_nodes,const NodeMap & node_map,const std::unordered_map<const NodeDef *,int> & components,const std::unordered_map<const NodeDef *,int> & recomputed_node_max_feed_components,GraphDef * graph)281 AddRecomputeControlDependencyNodes(
282     const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
283     const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
284     const std::unordered_map<const NodeDef*, int>& components,
285     const std::unordered_map<const NodeDef*, int>&
286         recomputed_node_max_feed_components,
287     GraphDef* graph) {
288   // Sort recomputed nodes based on max downstream components.
289   std::vector<const NodeDef*> recomputed_source_nodes_topological(
290       recomputed_source_nodes.begin(), recomputed_source_nodes.end());
291   std::sort(recomputed_source_nodes_topological.begin(),
292             recomputed_source_nodes_topological.end(),
293             [&recomputed_node_max_feed_components](const NodeDef* first,
294                                                    const NodeDef* second) {
295               int first_component =
296                   recomputed_node_max_feed_components.find(first)->second;
297               int second_component =
298                   recomputed_node_max_feed_components.find(second)->second;
299               return first_component > second_component
300                      // Ensure a consistent ordering. This is necessary because
301                      // we're working not with node component numbers (which are
302                      // unique) but with the maximum across nodes they feed into
303                      // (very much not unique).
304                      || (first_component == second_component &&
305                          first->name() > second->name());
306             });
307   // Create merged control dependency nodes by sorting target inputs
308   // topologically and zipper merging with the sorted recomputed nodes.
309   std::vector<const NodeDef*> target_inputs_topological;
310   for (const NodeDef* target_node : target_nodes) {
311     for (const string& target_input_name_raw : target_node->input()) {
312       const NodeDef* target_input = node_map.GetNode(target_input_name_raw);
313       // If this node has already had one of its inputs recomputed during this
314       // rewriting pass, we ignore that recomputed node here (it will not be in
315       // the NodeMap).
316       if (target_input == nullptr ||
317           recomputed_source_nodes.count(target_input) != 0 ||
318           components.find(target_node)->second ==
319               components.find(target_input)->second) {
320         continue;
321       }
322       target_inputs_topological.push_back(target_input);
323     }
324   }
325   std::sort(target_inputs_topological.begin(), target_inputs_topological.end(),
326             [&components](const NodeDef* first, const NodeDef* second) {
327               return components.find(first)->second >
328                      components.find(second)->second;
329             });
330   auto target_input_iterator = target_inputs_topological.begin();
331   NodeDef* current_trigger_node = nullptr;
332   std::unordered_map<const NodeDef*, const NodeDef*> triggers;
333   for (const NodeDef* original_recomputed_node :
334        recomputed_source_nodes_topological) {
335     NodeDef* new_trigger_node = graph->add_node();
336     new_trigger_node->set_name(AddPrefixToNodeName(
337         original_recomputed_node->name(), kRecomputeTriggerNodePrefix));
338     new_trigger_node->set_op("NoOp");
339     new_trigger_node->set_device(original_recomputed_node->device());
340     if (current_trigger_node != nullptr) {
341       *new_trigger_node->add_input() =
342           strings::StrCat("^", current_trigger_node->name());
343     }
344     current_trigger_node = new_trigger_node;
345     triggers[original_recomputed_node] = current_trigger_node;
346     for (;
347          target_input_iterator != target_inputs_topological.end() &&
348          components.find(*target_input_iterator)->second >
349              recomputed_node_max_feed_components.find(original_recomputed_node)
350                  ->second;
351          ++target_input_iterator) {
352       *current_trigger_node->add_input() =
353           strings::StrCat("^", (*target_input_iterator)->name());
354       VLOG(2) << "  Recomputation trigger " << current_trigger_node->name()
355               << " depends on " << (*target_input_iterator)->name();
356     }
357   }
358   return triggers;
359 }
360 
// Maps an input reference (an entry of NodeDef::input) to the name it should
// use after recomputation.
//
// If the node it refers to is in `recomputed_node_names`, returns the
// reference rewritten with kRecomputedNodePrefix; otherwise returns
// `original_node_name` unchanged. Input references may carry an output port
// ("foo:1") or a control marker ("^foo"). The membership test therefore uses
// the bare NodeName(), so that such references to recomputed nodes are
// redirected too. Previously only exact port-0 names matched, leaving e.g.
// FusedBatchNorm's extra outputs pinned to the original node.
// NOTE(review): relies on AddPrefixToNodeName keeping a leading "^" in front
// of the prefix and leaving a ":port" suffix intact - confirm in
// grappler/utils.
string RecomputedOrOriginalNodeName(
    const std::unordered_set<string>& recomputed_node_names,
    const string& original_node_name) {
  if (recomputed_node_names.count(NodeName(original_node_name)) == 0) {
    return original_node_name;
  }
  return AddPrefixToNodeName(original_node_name, kRecomputedNodePrefix);
}
371 
372 // Helper function to recompute a sub-graph (recomputed_source_nodes). Edges
373 // from recomputed_source_nodes to target_nodes are changed to start from the
374 // recomputed nodes.
RecomputeSubgraph(const std::unordered_set<const NodeDef * > & recomputed_source_nodes,const std::unordered_set<NodeDef * > & target_nodes,const NodeMap & node_map,const std::unordered_map<const NodeDef *,int> & components,GraphDef * graph)375 void RecomputeSubgraph(
376     const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
377     const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
378     const std::unordered_map<const NodeDef*, int>& components,
379     GraphDef* graph) {
380   std::unordered_set<string> recomputed_node_names;
381   VLOG(1) << "Recomputing a " << recomputed_source_nodes.size()
382           << " node subgraph";
383   std::unordered_map<const NodeDef*, int> recomputed_node_components =
384       GetMaxDownstreamComponents(recomputed_source_nodes, target_nodes,
385                                  node_map, components);
386   for (const NodeDef* original_node : recomputed_source_nodes) {
387     VLOG(2) << "  " << original_node->name();
388     recomputed_node_names.insert(original_node->name());
389   }
390   std::unordered_map<const NodeDef*, const NodeDef*> triggers =
391       AddRecomputeControlDependencyNodes(recomputed_source_nodes, target_nodes,
392                                          node_map, components,
393                                          recomputed_node_components, graph);
394   // Create the recomputed sub-graph
395   for (const NodeDef* original_node : recomputed_source_nodes) {
396     NodeDef* copied_node = graph->add_node();
397     copied_node->set_name(
398         AddPrefixToNodeName(original_node->name(), kRecomputedNodePrefix));
399     copied_node->set_op(original_node->op());
400     *copied_node->mutable_attr() = original_node->attr();
401     copied_node->set_device(original_node->device());
402     for (const string& original_input_name : original_node->input()) {
403       // Set inputs which are internal to the copied subgraph to their copied
404       // versions.
405       *copied_node->add_input() = RecomputedOrOriginalNodeName(
406           recomputed_node_names, original_input_name);
407     }
408     // Each recomputed node gets a control dependency to prevent it from being
409     // recomputed immediately.
410     *copied_node->add_input() =
411         strings::StrCat("^", triggers[original_node]->name());
412   }
413   // Set the inputs of nodes in the target subgraph to the recomputed nodes
414   // where applicable.
415   for (NodeDef* target_node : target_nodes) {
416     for (string& target_input_name : *target_node->mutable_input()) {
417       target_input_name = RecomputedOrOriginalNodeName(recomputed_node_names,
418                                                        target_input_name);
419     }
420   }
421 }
422 
RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,const string & recomputation_targets_name_scope,GraphDef * graph,const GrapplerItem & item)423 void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
424                                 const string& recomputation_targets_name_scope,
425                                 GraphDef* graph, const GrapplerItem& item) {
426   if (optimization_level != RewriterConfig::RECOMPUTATION_HEURISTICS &&
427       optimization_level != RewriterConfig::HEURISTICS &&
428       optimization_level != RewriterConfig::MANUAL) {
429     // Nothing to do
430     return;
431   }
432   // The topological numberings and NodeMap will be stale as soon as we start
433   // modifying the graph in RecomputeSubgraph. However, RecomputeSubgraph only
434   // looks up nodes which were in the original graph, and preserves the graph
435   // topology it's interested in.
436   // We don't use the results of this topological sort until later, but this
437   // call invalidates all NodeDef pointers, so it needs to be done before we
438   // start collecting those.
439   TF_CHECK_OK(TopologicalSort(graph));
440   NodeMap node_map(graph);
441   std::vector<RecomputedSubGraph> recomputed_subgraphs;
442   // Do not recompute nodes which are fed, since the recomputed node would not
443   // take on the fed value (i.e. gradients would be incorrect).
444   std::unordered_set<string> feeds;
445   for (const auto& feed : item.feed) {
446     feeds.insert(NodeName(feed.first));
447   }
448   std::function<bool(const NodeDef&)> is_target =
449       [&recomputation_targets_name_scope](const NodeDef& node) {
450         // Nodes whose inputs we may want to recompute. This matches node names
451         // that contain recomputation_targets_name_scope as a name scope,
452         // meaning it either begins with or contains the name scope.
453         // Defaults to "gradients/" which will match any node names that begins
454         // with "gradients/" or contains "/gradients/".
455         return node.name().find(recomputation_targets_name_scope) == 0 ||
456                node.name().find("/" + recomputation_targets_name_scope) != -1;
457       };
458 
459   if (optimization_level == RewriterConfig::RECOMPUTATION_HEURISTICS ||
460       optimization_level == RewriterConfig::HEURISTICS) {
461     // TODO(allenl): Handle ResNet-like architectures better. Right now all of
462     // the cheap forward ops get grouped into a single subgraph which must
463     // execute before gradients start executing (unless layers are manually
464     // separated by identity ops).
465     std::unordered_set<string> cheap_to_recompute_ops =
466         GetCheapToRecomputeOps();
467     recomputed_subgraphs = GetOpGroupsToRecompute(
468         graph, node_map,
469         [&cheap_to_recompute_ops, &feeds, &is_target](const NodeDef& node) {
470           return !is_target(node) && feeds.count(node.name()) == 0 &&
471                  (cheap_to_recompute_ops.count(node.op()) > 0 ||
472                   node.attr().count(kRecomputeHint) > 0);
473         },
474         is_target);
475   } else if (optimization_level == RewriterConfig::MANUAL) {
476     recomputed_subgraphs = GetOpGroupsToRecompute(
477         graph, node_map,
478         [&feeds, &is_target](const NodeDef& node) {
479           return !is_target(node) && feeds.count(node.name()) == 0 &&
480                  node.attr().count(kRecomputeHint) > 0;
481         },
482         is_target);
483   }
484   if (!recomputed_subgraphs.empty()) {
485     std::unordered_map<const NodeDef*, int> topological_numbering;
486     for (int node_number = 0; node_number < graph->node().size();
487          ++node_number) {
488       topological_numbering[graph->mutable_node(node_number)] =
489           graph->node().size() - node_number - 1;
490     }
491     // Duplicate the indicated sub-graphs and set up control dependencies
492     for (const RecomputedSubGraph& subgraph : recomputed_subgraphs) {
493       RecomputeSubgraph(subgraph.recomputed_source_nodes, subgraph.target_nodes,
494                         node_map, topological_numbering, graph);
495     }
496   }
497 }
498 
SchedulingPass(Cluster * cluster,GrapplerItem * item)499 bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
500   // Look for AddN nodes (and equivalent) and record input names.
501   MutableGraphView view(&item->graph);
502 
503   // It's ok to use immutable GraphTopologyView here, because we do not destroy
504   // any of the nodes in the underlying graph, we only add new nodes.
505   GraphTopologyView graph_topology;
506   Status initialized_topology = graph_topology.InitializeFromGraph(item->graph);
507   if (!initialized_topology.ok()) {
508     VLOG(1) << "Failed to initialize graph topology view: "
509             << initialized_topology.error_message();
510     return false;
511   }
512 
513   std::unordered_map<string, std::unordered_set<NodeDef*>> addn_list;
514   for (NodeDef& node : *item->graph.mutable_node()) {
515     if (!IsAddN(node) && node.op() != "AccumulateNV2") {
516       continue;
517     }
518     // There is nothing to gain by optimizing nodes with 2 or fewer inputs.
519     if (view.NumFanins(node, false) <= 2) {
520       continue;
521     }
522     for (const auto& input : view.GetFanins(node, false)) {
523       if (input.node->device() == node.device()) {
524         string tensor_name =
525             strings::StrCat(input.node->name(), ":", input.port_id);
526         addn_list[tensor_name].insert(&node);
527       }
528     }
529   }
530 
531   if (addn_list.empty()) {
532     return false;
533   }
534 
535   GraphMemory memory(*item);
536   const std::unordered_map<string, DeviceProperties>& devices =
537       cluster->GetDevices();
538   Status s = memory.InferStatically(devices);
539   if (!s.ok()) {
540     VLOG(1) << "Failed to infer memory usage: " << s.error_message();
541     return false;
542   }
543 
544   std::unordered_set<NodeDef*> addn_to_rewrite;
545   for (const auto& device : devices) {
546     const string& name = device.first;
547     const DeviceProperties& prop = device.second;
548     if (prop.memory_size() <= 0) {
549       VLOG(1) << "Available memory unknown for device " << name;
550       continue;
551     }
552     const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);
553 
554     if (mem_usage.used_memory <= prop.memory_size() * 0.8) {
555       continue;
556     }
557 
558     for (const auto& live : mem_usage.live_tensors) {
559       string tensor_name = strings::StrCat(live.node, ":", live.output_id);
560       auto it = addn_list.find(tensor_name);
561       if (it != addn_list.end()) {
562         addn_to_rewrite.insert(it->second.begin(), it->second.end());
563       }
564     }
565   }
566 
567   if (addn_to_rewrite.empty()) {
568     return false;
569   }
570   GraphProperties properties(*item);
571   s = properties.InferStatically(false);
572   if (!s.ok()) {
573     VLOG(1) << "Failed to infer shapes: " << s.error_message();
574     return false;
575   }
576 
577   bool updated_graph = false;
578   // Rewrite the AddN.
579   for (NodeDef* node : addn_to_rewrite) {
580     if (!properties.HasOutputProperties(node->name())) {
581       VLOG(1) << "Missing properties for " << node->name();
582       continue;
583     }
584     const TensorShapeProto& shape =
585         properties.GetOutputProperties(node->name())[0].shape();
586     PartialTensorShape shp(shape);
587     if (!shp.IsFullyDefined()) {
588       VLOG(1) << "Shape not fully known for " << node->name();
589       continue;
590     }
591 
592     // Compute a topological ordering for the node fanin.
593     std::unordered_map<const NodeDef*, int> topo_order;
594     DfsTraversal(graph_topology, {node}, TraversalDirection::kFollowInputs,
595                  DfsCallbacks::PostOrder([&topo_order](const NodeDef* n) {
596                    int topo_index = static_cast<int>(topo_order.size());
597                    topo_order[n] = topo_index;
598                  }));
599 
600     std::vector<int> input_topo_index;
601 
602     for (int i = 0; i < node->input_size(); ++i) {
603       const string& input = node->input(i);
604       const string node_name = NodeName(input);
605       const NodeDef* node = view.GetNode(node_name);
606       input_topo_index.push_back(topo_order.at(node));
607     }
608     int min_input_topo_index = INT_MAX;
609     int min_input_id = -1;
610     for (int i = 0; i < node->input_size(); ++i) {
611       if (IsControlInput(node->input(i))) {
612         // control inputs are always last.
613         break;
614       }
615       const int current = input_topo_index[i];
616       if (current < min_input_topo_index) {
617         min_input_topo_index = current;
618         min_input_id = i;
619       }
620     }
621     CHECK_LE(0, min_input_id);
622     std::vector<string> pre_ctrl_deps;
623     std::vector<string> post_ctrl_deps;
624     for (int i = node->input_size() - 1; i >= 0; --i) {
625       if (!IsControlInput(node->input(i))) {
626         // control inputs are always last.
627         break;
628       }
629       if (input_topo_index[i] < min_input_topo_index) {
630         // These control dependencies can be executed before the node.
631         pre_ctrl_deps.push_back(node->input(i));
632       } else {
633         // These control dependencies should be executed after the node.
634         post_ctrl_deps.push_back(node->input(i));
635       }
636     }
637 
638     DataType dtype = node->attr().at("T").type();
639     const string& device = node->device();
640 
641     // Create the temporary variable that will hold intermediate results
642     NodeDef* tmp_var = item->graph.add_node();
643     tmp_var->set_name(strings::StrCat(node->name(), "/tmp_var"));
644     tmp_var->set_op("TemporaryVariable");
645     tmp_var->set_device(device);
646     (*tmp_var->mutable_attr())["dtype"].set_type(dtype);
647     *(*tmp_var->mutable_attr())["shape"].mutable_shape() = shape;
648     (*tmp_var->mutable_attr())["var_name"].set_s(tmp_var->name());
649 
650     for (const string& ctrl_dep : pre_ctrl_deps) {
651       *tmp_var->add_input() = ctrl_dep;
652     }
653     *tmp_var->add_input() =
654         AsControlDependency(NodeName(node->input(min_input_id)));
655 
656     // Initialize it to zero
657     NodeDef* zeros = item->graph.add_node();
658     zeros->set_name(strings::StrCat(node->name(), "/tmp_var_zeros"));
659     zeros->set_op("ZerosLike");
660     zeros->set_device(device);
661     (*zeros->mutable_attr())["T"].set_type(dtype);
662     *zeros->add_input() = node->input(min_input_id);
663 
664     NodeDef* initialize = item->graph.add_node();
665     initialize->set_name(strings::StrCat(node->name(), "/tmp_var_initializer"));
666     initialize->set_op("Assign");
667     initialize->set_device(device);
668     (*initialize->mutable_attr())["T"].set_type(dtype);
669     (*initialize->mutable_attr())["use_locking"].set_b(false);
670     (*initialize->mutable_attr())["validate_shape"].set_b(false);
671     *initialize->add_input() = tmp_var->name();
672     *initialize->add_input() = zeros->name();
673 
674     // Add the assignadd nodes
675     std::vector<NodeDef*> accumulates;
676     for (int i = 0; i < node->input_size(); ++i) {
677       const string& input = node->input(i);
678       if (!IsControlInput(input)) {
679         NodeDef* accumulate = item->graph.add_node();
680         accumulate->set_name(
681             strings::StrCat(node->name(), "/tmp_var_accum_", i));
682         accumulate->set_op("AssignAdd");
683         accumulate->set_device(device);
684         (*accumulate->mutable_attr())["T"].set_type(dtype);
685         (*accumulate->mutable_attr())["use_locking"].set_b(true);
686         *accumulate->add_input() = initialize->name();
687         *accumulate->add_input() = input;
688         accumulates.push_back(accumulate);
689       }
690     }
691 
692     // Rewrite the AddN node as a DestroyTemporaryVariable ops
693     node->set_op("DestroyTemporaryVariable");
694     node->clear_input();
695     node->clear_attr();
696     (*node->mutable_attr())["T"].set_type(dtype);
697     (*node->mutable_attr())["var_name"].set_s(tmp_var->name());
698     *node->add_input() = initialize->name();
699     for (const NodeDef* accum : accumulates) {
700       *node->add_input() = AsControlDependency(accum->name());
701     }
702     for (const string& ctrl_dep : post_ctrl_deps) {
703       *node->add_input() = ctrl_dep;
704     }
705 
706     updated_graph = true;
707   }
708 
709   return updated_graph;
710 }
711 
BuildSwapPair(NodeDef * node,int input_to_swap,const std::unordered_map<string,const NodeDef * > & name_map,GraphDef * graph,std::pair<NodeDef *,NodeDef * > * swap_pair)712 Status BuildSwapPair(NodeDef* node, int input_to_swap,
713                      const std::unordered_map<string, const NodeDef*>& name_map,
714                      GraphDef* graph,
715                      std::pair<NodeDef*, NodeDef*>* swap_pair) {
716   string task, device;
717   if (!DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) ||
718       !str_util::StrContains(device, DEVICE_GPU)) {
719     return errors::InvalidArgument("Can't swap input ", input_to_swap,
720                                    " of node ", node->name(),
721                                    " since it is not on GPU");
722   }
723   const OpDef* op_def;
724   TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
725   DataType input_type;
726   TF_RETURN_IF_ERROR(
727       InputTypeForNode(*node, *op_def, input_to_swap, &input_type));
728   if (IsRefType(input_type)) {
729     return errors::InvalidArgument("Can't swap input ", input_to_swap,
730                                    " of node ", node->name(),
731                                    " since it expects a reference");
732   }
733 
734   string tensor_to_swap = strings::StrCat(node->name(), "_", input_to_swap);
735   string swap_out_name = strings::StrCat("swap_out_", tensor_to_swap);
736   string swap_in_name = strings::StrCat("swap_in_", tensor_to_swap);
737   if (name_map.find(swap_out_name) != name_map.end() ||
738       name_map.find(swap_in_name) != name_map.end()) {
739     return errors::InvalidArgument("Input ", input_to_swap, " of node ",
740                                    node->name(), " is already swapped");
741   }
742 
743   // Force the tensor to be copied to cpu.
744   NodeDef* swap_out_node = graph->add_node();
745   swap_out_node->set_name(swap_out_name);
746   swap_out_node->set_op("_CopyFromGpuToHost");
747 
748   // Force the tensor to be restored to the device.
749   NodeDef* swap_in_node = graph->add_node();
750   swap_in_node->set_name(swap_in_name);
751   swap_in_node->set_op("_CopyFromHostToGpu");
752   *swap_in_node->add_input() = swap_out_node->name();
753 
754   // Colocate the swap_out_ and swap_in_ nodes with the node itself.
755   swap_out_node->set_device(node->device());
756   swap_in_node->set_device(node->device());
757   string coloc_group = strings::StrCat("loc@", tensor_to_swap);
758   (*swap_out_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
759   (*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
760   (*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
761 
762   (*swap_in_node->mutable_attr())["T"].set_type(input_type);
763   (*swap_out_node->mutable_attr())["T"].set_type(input_type);
764   *swap_pair = std::make_pair(swap_out_node, swap_in_node);
765 
766   return Status::OK();
767 }
768 
769 struct SwapInfo {
770   std::vector<int> inputs_to_swap;
771   Costs::NanoSeconds time_to_swap = 0;
772 };
773 
FindSwapInTrigger(const NodeDef * node,const SwapInfo & swap_info,const std::unordered_map<string,const NodeDef * > & name_map,const std::unordered_map<const NodeDef *,Costs::NanoSeconds> & execution_times)774 static const NodeDef* FindSwapInTrigger(
775     const NodeDef* node, const SwapInfo& swap_info,
776     const std::unordered_map<string, const NodeDef*>& name_map,
777     const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
778         execution_times) {
779   // max_trigger_time stores the time before which the swap operation needs to
780   // be started in order to load the data back onto the accelerator without
781   // delaying the downstream computation.
782   Costs::NanoSeconds max_trigger_time(0);
783   std::set<string> possible_inputs;
784   for (int i = 0; i < node->input_size(); ++i) {
785     const string input_node_name = NodeName(node->input(i));
786     auto it1 = name_map.find(input_node_name);
787     if (it1 == name_map.end()) {
788       return nullptr;
789     }
790     const NodeDef* input_node = it1->second;
791 
792     auto it2 = execution_times.find(input_node);
793     if (it2 == execution_times.end()) {
794       return nullptr;
795     }
796     max_trigger_time = std::max(max_trigger_time, it2->second);
797     possible_inputs.insert(input_node_name);
798   }
799 
800   for (const int i : swap_info.inputs_to_swap) {
801     const string input_node_name = NodeName(node->input(i));
802     possible_inputs.erase(input_node_name);
803   }
804   if (possible_inputs.empty()) {
805     return nullptr;
806   }
807 
808   max_trigger_time -= swap_info.time_to_swap;
809 
810   std::map<Costs::NanoSeconds, const NodeDef*> candidates;
811   std::set<string> already_processed;
812 
813   while (!possible_inputs.empty()) {
814     const string input_node_name = *possible_inputs.begin();
815     possible_inputs.erase(possible_inputs.begin());
816     already_processed.insert(input_node_name);
817     auto it1 = name_map.find(input_node_name);
818     if (it1 == name_map.end()) {
819       return nullptr;
820     }
821     const NodeDef* input_node = it1->second;
822     // Don't jump over frames, since adding a control dependency from one frame
823     // to the next isn't supported. Don't go through branches, since we don't
824     // know whether they'll be executed or not.
825     if (ModifiesFrameInfo(*input_node) || IsSwitch(*input_node) ||
826         IsMerge(*input_node)) {
827       continue;
828     }
829     auto it2 = execution_times.find(input_node);
830     if (it2 == execution_times.end()) {
831       return nullptr;
832     }
833     if (it2->second < max_trigger_time) {
834       candidates[it2->second] = input_node;
835     } else {
836       for (const string& fanin : input_node->input()) {
837         string name = NodeName(fanin);
838         if (already_processed.find(name) == already_processed.end()) {
839           possible_inputs.insert(name);
840         }
841       }
842     }
843   }
844 
845   // Select the candidate that will execute last, since we want to swap the data
846   // back at the last minute while still allowing enough time for data to be
847   // swapped back timely to feed the downstream nodes.
848   if (!candidates.empty()) {
849     return candidates.rbegin()->second;
850   }
851   return nullptr;
852 }
853 
IsSwappable(const MutableGraphView & graph,MutableGraphView::OutputPort output)854 static bool IsSwappable(const MutableGraphView& graph,
855                         MutableGraphView::OutputPort output) {
856   const NodeDef& node = *output.node;
857   // There is no point in swapping out persistent tensors, since the tensor will
858   // continue to use memory.
859   if (IsPersistent(node)) {
860     return false;
861   }
862 
863   const OpDef* op_def;
864   if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
865     return false;
866   }
867   DataType dtype;
868   if (!OutputTypeForNode(node, *op_def, output.port_id, &dtype).ok()) {
869     return false;
870   }
871   // References can only refer to persistent memory: therefore the node isn't
872   // swappable.
873   if (IsRefType(dtype)) {
874     return false;
875   }
876 
877   if (output.node->op() == "Identity" || output.node->op() == "Reshape") {
878     // If placed on the same device, these nodes are just forwarding references
879     // to their input. Therefore they are swappable iff their fanin is swappable
880     // or it resides on a different device.
881     MutableGraphView::InputPort input;
882     input.node = output.node;
883     input.port_id = 0;
884     MutableGraphView::OutputPort fanin = graph.GetRegularFanin(input);
885     if (fanin.node->device() == node.device()) {
886       return IsSwappable(graph, fanin);
887     }
888   }
889   return true;
890 }
891 
// Returns the consumer of the tensor feeding input `input_id` of `node` that is
// estimated to execute first, excluding `node` itself. The caller makes that
// consumer wait on the swap-out copy, so the tensor is moved to host memory
// soon after it is produced. Returns nullptr if the producing port can't be
// resolved, or if no other consumer has a known (finite) execution time.
// Among consumers with equal estimated times, the one with the smallest name
// is chosen.
static NodeDef* FindSwapOutTrigger(
    const NodeDef* node, int input_id, const MutableGraphView& view,
    const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
        execution_times) {
  // Find the output port that generated the tensor to swap.
  MutableGraphView::InputPort swap;
  swap.node = const_cast<NodeDef*>(node);
  swap.port_id = input_id;
  MutableGraphView::OutputPort generator = view.GetRegularFanin(swap);
  if (!generator.node) {
    return nullptr;
  }

  const absl::flat_hash_set<MutableGraphView::InputPort>& fanout =
      view.GetFanout(generator);
  NodeDef* trigger = nullptr;
  Costs::NanoSeconds earliest_fanout(Costs::NanoSeconds::infinity());

  for (const auto& port : fanout) {
    if (port.node == node) {
      continue;
    }
    auto it = execution_times.find(port.node);
    if (it == execution_times.end()) {
      continue;
    }
    const Costs::NanoSeconds& exec_time = it->second;
    // The fanout is a hash set with unspecified iteration order. Break ties on
    // the node name so the same trigger is chosen on every run.
    const bool earlier = exec_time < earliest_fanout;
    const bool wins_tie = trigger != nullptr && exec_time == earliest_fanout &&
                          port.node->name() < trigger->name();
    if (earlier || wins_tie) {
      earliest_fanout = exec_time;
      trigger = port.node;
    }
  }

  return trigger;
}
923 
IsSwappable(MutableGraphView::InputPort input)924 static bool IsSwappable(MutableGraphView::InputPort input) {
925   const NodeDef& node = *input.node;
926 
927   const OpDef* op_def;
928   if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
929     return false;
930   }
931 
932   DataType dtype;
933   if (!InputTypeForNode(node, *op_def, input.port_id, &dtype).ok()) {
934     return false;
935   }
936 
937   return !IsRefType(dtype);
938 }
939 
940 struct MemInfo {
941   MutableGraphView::OutputPort port;
942   int64 memory_used;
943   std::vector<MutableGraphView::InputPort> uses_left;
944   double fitness;
945 
operator <tensorflow::grappler::__anon0f3025ed0111::MemInfo946   bool operator<(const MemInfo& other) const { return fitness < other.fitness; }
947 };
948 
IdentifySwappingCandidates(Cluster * cluster,GrapplerItem * item,std::unordered_set<string> * skip_list,std::unordered_map<NodeDef *,SwapInfo> * nodes_to_swap)949 static bool IdentifySwappingCandidates(
950     Cluster* cluster, GrapplerItem* item, std::unordered_set<string>* skip_list,
951     std::unordered_map<NodeDef*, SwapInfo>* nodes_to_swap) {
952   GraphMemory memory(*item);
953   const std::unordered_map<string, DeviceProperties>& devices =
954       cluster->GetDevices();
955   Status s = memory.InferStatically(devices);
956   if (!s.ok()) {
957     VLOG(1) << "Failed to infer memory usage: " << s.error_message();
958     return false;
959   }
960 
961   bool updated_graph = false;
962   for (const auto& device : devices) {
963     const string& name = device.first;
964     const DeviceProperties& prop = device.second;
965     if (prop.type() != "GPU") {
966       continue;
967     }
968     if (prop.memory_size() <= 0) {
969       VLOG(1) << "Peak memory usage unknown for device " << name;
970       continue;
971     }
972     const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);
973 
974     if (mem_usage.used_memory <= prop.memory_size()) {
975       continue;
976     }
977     int64 required_savings = mem_usage.used_memory - prop.memory_size();
978 
979     std::unordered_map<string, Costs::NanoSeconds> op_completion_times;
980     {
981       VirtualCluster vcluster(cluster->GetDevices());
982       if (!vcluster.Provision().ok()) {
983         return false;
984       }
985       if (!vcluster.Initialize(*item).ok()) {
986         return false;
987       }
988       RunMetadata metadata;
989       Status s = vcluster.Run(item->graph, item->feed, item->fetch, &metadata);
990       if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) {
991         return false;
992       }
993 
994       for (const auto& dev_stats : metadata.step_stats().dev_stats()) {
995         for (const auto& node_stats : dev_stats.node_stats()) {
996           Costs::NanoSeconds exec_time =
997               Costs::NanoSeconds(1) +
998               Costs::MicroSeconds(node_stats.all_start_micros() +
999                                   node_stats.op_end_rel_micros());
1000           op_completion_times.emplace(node_stats.node_name(), exec_time);
1001         }
1002       }
1003     }
1004 
1005     Costs::Duration peak_time = -1;
1006     for (const auto& live_tensor : mem_usage.live_tensors) {
1007       if (live_tensor.allocation_time > peak_time) {
1008         peak_time = live_tensor.allocation_time;
1009       }
1010     }
1011 
1012     std::vector<MemInfo> mem_state;
1013 
1014     MutableGraphView graph(&item->graph);
1015     for (const auto& live_tensor : mem_usage.live_tensors) {
1016       if (live_tensor.memory_used <= 1024) {
1017         // Don't bother with small tensors.
1018         continue;
1019       }
1020       if (live_tensor.deallocation_time - live_tensor.allocation_time <=
1021           Costs::Duration(1e6)) {
1022         // Not enough time to swap.
1023         VLOG(1) << "Not enough time to swap: skipping " << live_tensor.node;
1024         continue;
1025       }
1026 
1027       if (skip_list->find(live_tensor.node) != skip_list->end()) {
1028         continue;
1029       }
1030       MutableGraphView::OutputPort port =
1031           graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
1032       if (!IsSwappable(graph, port)) {
1033         continue;
1034       }
1035       MemInfo mem_info;
1036       mem_info.port = port;
1037       mem_info.memory_used = live_tensor.memory_used;
1038       Costs::Duration allocation_time = live_tensor.allocation_time;
1039       Costs::Duration earliest_use(Costs::Duration::infinity());
1040       bool valid = true;
1041       for (MutableGraphView::InputPort input : graph.GetFanout(port)) {
1042         // Get execution time.
1043         auto it = op_completion_times.find(input.node->name());
1044         if (it == op_completion_times.end()) {
1045           valid = false;
1046           break;
1047         }
1048         if (it->second <= peak_time) {
1049           continue;
1050         }
1051 
1052         if (skip_list->find(input.node->name()) != skip_list->end()) {
1053           valid = false;
1054           break;
1055         }
1056         string input_name =
1057             strings::StrCat(input.node->name(), ":", input.port_id);
1058         if (skip_list->find(input_name) != skip_list->end()) {
1059           valid = false;
1060           break;
1061         }
1062         if (!IsSwappable(input)) {
1063           valid = false;
1064           break;
1065         }
1066 
1067         // Set earliest use time that's after peak.
1068         mem_info.uses_left.emplace_back(input);
1069         earliest_use = std::min(earliest_use, it->second);
1070       }
1071       if (valid && !mem_info.uses_left.empty()) {
1072         // Compute the fitness: we need the tensor to be generated way away of
1073         // the time of peak memory usage (to ensure there is enough time to swap
1074         // it out). We also need to ensure it's used way after the peak time, to
1075         // ensure that swapping the tensor back in won't recreate the memory
1076         // bottleneck. Last but not least, we want the tensor to have as few
1077         // remaining uses as possible.
1078         //
1079         // Note that we must perform the arithmetic inexactly as "double", since
1080         // the values do not fit into any integral type.
1081         mem_info.fitness =
1082             MathUtil::IPow<double>((earliest_use - peak_time).count(), 2) /
1083                 MathUtil::IPow<double>(mem_info.uses_left.size(), 2) +
1084             MathUtil::IPow<double>((allocation_time - peak_time).count(), 2);
1085         mem_info.fitness = -mem_info.fitness;
1086         mem_state.push_back(mem_info);
1087       }
1088     }
1089 
1090     // Sort by fitness
1091     std::sort(mem_state.begin(), mem_state.end());
1092 
1093     for (const MemInfo& mem_info : mem_state) {
1094       for (const MutableGraphView::InputPort fanout_to_swap :
1095            mem_info.uses_left) {
1096         VLOG(1) << "Will swap fanout " << fanout_to_swap.node->name() << ":"
1097                 << fanout_to_swap.port_id << " of tensor "
1098                 << mem_info.port.node->name() << ":" << mem_info.port.port_id
1099                 << " of size " << mem_info.memory_used;
1100 
1101         (*nodes_to_swap)[fanout_to_swap.node].inputs_to_swap.push_back(
1102             fanout_to_swap.port_id);
1103       }
1104       required_savings -= mem_info.memory_used;
1105       updated_graph = true;
1106       if (required_savings < 0) {
1107         break;
1108       }
1109     }
1110   }
1111   return updated_graph;
1112 }
1113 
SwappingPass(RewriterConfig::MemOptType optimization_level,Cluster * cluster,GrapplerItem * item,std::unordered_set<string> * skip_list)1114 bool SwappingPass(RewriterConfig::MemOptType optimization_level,
1115                   Cluster* cluster, GrapplerItem* item,
1116                   std::unordered_set<string>* skip_list) {
1117   std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap;
1118   if (optimization_level == RewriterConfig::DEFAULT_MEM_OPT ||
1119       optimization_level == RewriterConfig::SWAPPING_HEURISTICS ||
1120       optimization_level == RewriterConfig::HEURISTICS) {
1121     // Use heuristics to figure out what needs to be swapped;
1122     IdentifySwappingCandidates(cluster, item, skip_list, &nodes_to_swap);
1123   }
1124   // Look for manual annotatations in the graph.
1125   for (auto& node : *item->graph.mutable_node()) {
1126     if (node.attr().count("_swap_to_host") != 0) {
1127       SwapInfo& swap_info = nodes_to_swap[&node];
1128       const AttrValue& val = node.attr().at("_swap_to_host");
1129       if (val.has_list()) {
1130         for (int64 input_id : val.list().i()) {
1131           swap_info.inputs_to_swap.push_back(input_id);
1132         }
1133       } else {
1134         int64 input_id = val.i();
1135         swap_info.inputs_to_swap.push_back(input_id);
1136       }
1137     }
1138   }
1139   if (nodes_to_swap.empty()) {
1140     // Nothing to do.
1141     return false;
1142   }
1143 
1144   // Estimate the size of the data to swap for each node.
1145   GraphProperties properties(*item);
1146   if (!properties.InferStatically(true).ok()) {
1147     return false;
1148   }
1149   for (auto& swap : nodes_to_swap) {
1150     const NodeDef* node = swap.first;
1151     const std::vector<OpInfo::TensorProperties>& props =
1152         properties.GetInputProperties(node->name());
1153     SwapInfo& swap_info = swap.second;
1154     int64 bytes_to_swap = 0;
1155     for (int64 input_id : swap_info.inputs_to_swap) {
1156       const OpInfo::TensorProperties& t = props[input_id];
1157       bytes_to_swap += CalculateTensorSize(t);
1158     }
1159     // Let's assume we're going to swap over PCIe running at 16 GBps.
1160     swap_info.time_to_swap = bytes_to_swap / 16;
1161   }
1162 
1163   std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
1164   if (!EstimateEarliestExecutionTimes(*item, cluster, &execution_times).ok()) {
1165     return false;
1166   }
1167 
1168   std::unordered_map<string, const NodeDef*> name_map;
1169   for (const auto& node : item->graph.node()) {
1170     name_map[node.name()] = &node;
1171   }
1172   MutableGraphView view(&item->graph);
1173 
1174   bool updated_graph = false;
1175 
1176   for (auto& swap : nodes_to_swap) {
1177     NodeDef* node = swap.first;
1178     const SwapInfo& swap_info = swap.second;
1179     if (skip_list->find(node->name()) != skip_list->end()) {
1180       continue;
1181     }
1182 
1183     // Make sure the tensor isn't swapped back in right away: look for node that
1184     // will execute just before we need to swap the data back, and add a control
1185     // dependency from that node to the swap node.
1186     const NodeDef* in_trigger =
1187         FindSwapInTrigger(node, swap_info, name_map, execution_times);
1188     // If we failed, don't attempt to reprocess this node in a subsequent pass.
1189     if (!in_trigger) {
1190       skip_list->insert(node->name());
1191       continue;
1192     }
1193 
1194     // Swap all the tensors that are marked with the 'swap_to_host' attribute.
1195     for (int input_id : swap_info.inputs_to_swap) {
1196       string input_name = strings::StrCat(node->name(), ":", input_id);
1197       if (skip_list->find(input_name) != skip_list->end()) {
1198         continue;
1199       } else {
1200         // Don't attempt to reprocess this input in a subsequent pass.
1201         skip_list->insert(input_name);
1202       }
1203 
1204       // Make sure the tensor is swapped out quickly: look for node that
1205       // will execute just after the tensor is generated and add a control
1206       // dependency from the swap out node to that node.
1207       NodeDef* out_trigger =
1208           FindSwapOutTrigger(node, input_id, view, execution_times);
1209       if (!out_trigger) {
1210         continue;
1211       }
1212 
1213       std::pair<NodeDef*, NodeDef*> swap_nodes;
1214       if (!BuildSwapPair(node, input_id, name_map, &item->graph, &swap_nodes)
1215                .ok()) {
1216         continue;
1217       }
1218       *swap_nodes.first->add_input() = node->input(input_id);
1219       *node->mutable_input(input_id) = swap_nodes.second->name();
1220 
1221       // Add the control dependencies needed to delay the execution of the swap.
1222       out_trigger->add_input(strings::StrCat("^", swap_nodes.first->name()));
1223       swap_nodes.second->add_input(strings::StrCat("^", in_trigger->name()));
1224 
1225       // Make sure we won't try to swap the swap nodes in subsequent passes.
1226       skip_list->insert(swap_nodes.first->name());
1227       skip_list->insert(swap_nodes.second->name());
1228     }
1229   }
1230   return updated_graph;
1231 }
1232 
CrossesTaskOrCpuGpuBoundary(const NodeDef & node1,const NodeDef & node2)1233 bool CrossesTaskOrCpuGpuBoundary(const NodeDef& node1, const NodeDef& node2) {
1234   string task1;
1235   string device1;
1236   DeviceNameUtils::SplitDeviceName(node1.device(), &task1, &device1);
1237   string task2;
1238   string device2;
1239   DeviceNameUtils::SplitDeviceName(node2.device(), &task2, &device2);
1240   return task1 != task2 ||
1241          (str_util::StrContains(device1, DEVICE_CPU) &&
1242           str_util::StrContains(device2, DEVICE_GPU)) ||
1243          (str_util::StrContains(device1, DEVICE_GPU) &&
1244           str_util::StrContains(device2, DEVICE_CPU));
1245 }
1246 
1247 // TODO(rmlarsen): Add distributed TF test.
RelaxAllocatorConstraints(GraphDef * optimized_graph)1248 Status RelaxAllocatorConstraints(GraphDef* optimized_graph) {
1249   std::unordered_set<string> devices;
1250   std::vector<int> assign_nodes;
1251   bool found_send = false;
1252   for (int i = 0; i < optimized_graph->node_size(); ++i) {
1253     const NodeDef& node = optimized_graph->node(i);
1254     devices.insert(node.device());
1255     if (IsAssign(node)) {
1256       assign_nodes.push_back(i);
1257     }
1258     if (IsSend(node)) {
1259       found_send = true;
1260       break;
1261     }
1262   }
1263   if (!found_send && devices.size() == 1) {
1264     for (int assign_idx : assign_nodes) {
1265       // Set an attribute telling AssignOp to ignore allocator constraints.
1266       NodeDef* assign_node = optimized_graph->mutable_node(assign_idx);
1267       (*assign_node->mutable_attr())["_grappler_relax_allocator_constraints"]
1268           .set_b(true);
1269     }
1270     return Status::OK();
1271   }
1272 
1273   GraphTopologyView graph_view;
1274   TF_RETURN_IF_ERROR(graph_view.InitializeFromGraph(
1275       *optimized_graph, /*ignore_control_edges=*/true));
1276   std::unordered_set<const NodeDef*> optimized_nodes;
1277 
1278   for (int i : assign_nodes) {
1279     const NodeDef& assign_node = optimized_graph->node(i);
1280 
1281     if (optimized_nodes.find(&assign_node) == optimized_nodes.end()) {
1282       std::vector<const NodeDef*> assign_nodes_in_fanout;
1283       optimized_nodes.insert(&assign_node);
1284       assign_nodes_in_fanout.push_back(&assign_node);
1285 
1286       std::vector<const NodeDef*> transitive_fanout;
1287       // Find the nodes in transitive fanout. If a node is known to never
1288       // forward its inputs, we can skip its fanout.
1289       DfsTraversal(graph_view, {graph_view.GetNode(i)},
1290                    TraversalDirection::kFollowOutputs,
1291                    DfsPredicates::Advance([&](const NodeDef* node) {
1292                      return !NeverForwardsInputs(*node);
1293                    }),
1294                    DfsCallbacks::PreOrder([&](const NodeDef* node) {
1295                      transitive_fanout.push_back(node);
1296                    }));
1297 
1298       bool relax_constraint = true;
1299       // If all nodes in the transitive fanout are on the same device as the
1300       // assign node, there is no need to allocate the output in pinned memory.
1301       for (const NodeDef* fanout_node : transitive_fanout) {
1302         if (relax_constraint &&
1303             (IsSend(*fanout_node) ||
1304              CrossesTaskOrCpuGpuBoundary(*fanout_node, assign_node))) {
1305           relax_constraint = false;
1306           break;
1307         }
1308         if (optimized_nodes.find(fanout_node) == optimized_nodes.end() &&
1309             IsAssign(*fanout_node)) {
1310           assign_nodes_in_fanout.push_back(fanout_node);
1311         }
1312       }
1313 
1314       if (relax_constraint) {
1315         for (const NodeDef* assign_node_in_fanout : assign_nodes_in_fanout) {
1316           // If all devices match in fanout of node(i) then, by transitivity,
1317           // they must also match in the fanout of other assign nodes
1318           // in the fanout of node(i), so we can process them here,
1319           // and save computing their transitive fanout later.
1320           optimized_nodes.insert(assign_node_in_fanout);
1321 
1322           // Set an attribute telling AssignOp to ignore allocator constraints.
1323           const absl::optional<int> assign_node_idx =
1324               graph_view.GetNodeIndex(*assign_node_in_fanout);
1325           NodeDef* assign_node_to_relax =
1326               optimized_graph->mutable_node(assign_node_idx.value());
1327           (*assign_node_to_relax
1328                 ->mutable_attr())["_grappler_relax_allocator_constraints"]
1329               .set_b(true);
1330         }
1331       }
1332     }
1333   }
1334   return Status::OK();
1335 }
1336 
1337 }  // namespace
1338 
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)1339 Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
1340                                  GraphDef* optimized_graph) {
1341   GrapplerItem optimized_item(item);
1342 
1343   RecomputationRewritingPass(optimization_level_,
1344                              recomputation_targets_name_scope_,
1345                              &optimized_item.graph, item);
1346 
1347   std::unordered_set<string> skip_list;
1348   // Bound the number of rewrite passes to avoid long processing times on graphs
1349   // that simply won't fit in memory.
1350   bool updated_graph = true;
1351   for (int i = 0; i < 25 && updated_graph; ++i) {
1352     GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
1353     updated_graph = false;
1354     if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT ||
1355          optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS ||
1356          optimization_level_ == RewriterConfig::HEURISTICS) &&
1357         cluster != nullptr) {
1358       updated_graph |= SchedulingPass(cluster, &optimized_item);
1359     }
1360 
1361     GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
1362     if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT ||
1363          optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS ||
1364          optimization_level_ == RewriterConfig::HEURISTICS ||
1365          optimization_level_ == RewriterConfig::MANUAL) &&
1366         cluster != nullptr) {
1367       updated_graph |= SwappingPass(optimization_level_, cluster,
1368                                     &optimized_item, &skip_list);
1369     }
1370   }
1371 
1372   TF_RETURN_IF_ERROR(RelaxAllocatorConstraints(&optimized_item.graph));
1373 
1374   optimized_graph->Swap(&optimized_item.graph);
1375   return Status::OK();
1376 }
1377 
Feedback(Cluster * cluster,const GrapplerItem & item,const GraphDef & optimized_graph,double result)1378 void MemoryOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
1379                                const GraphDef& optimized_graph, double result) {
1380   // Nothing to do for MemoryOptimizer.
1381 }
1382 
1383 }  // end namespace grappler
1384 }  // end namespace tensorflow
1385