1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
17 
18 #include <unordered_set>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/function.pb.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/node_def_builder.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/framework/versions.pb.h"
28 #include "tensorflow/core/grappler/grappler_item.h"
29 #include "tensorflow/core/grappler/mutable_graph_view.h"
30 #include "tensorflow/core/grappler/op_types.h"
31 #include "tensorflow/core/grappler/utils.h"
32 #include "tensorflow/core/grappler/utils/transitive_fanin.h"
33 
34 namespace tensorflow {
35 namespace grappler {
36 namespace {
37 
IsTrivialIdentity(const NodeDef & node,const GraphView & graph_view)38 bool IsTrivialIdentity(const NodeDef& node, const GraphView& graph_view) {
39   for (const auto input :
40        graph_view.GetFanins(node, /*include_controlling_nodes=*/true)) {
41     if (input.port_id == Graph::kControlSlot) {
42       // Node is driven by control dependency.
43       return false;
44     } else if (IsSwitch(*input.node)) {  // Node is driven by switch.
45       return false;
46     }
47   }
48   for (const auto output :
49        graph_view.GetFanouts(node, /*include_controlled_nodes=*/true)) {
50     if (output.port_id == Graph::kControlSlot) {
51       // Node drives control dependency.
52       return false;
53     } else if (IsMerge(*output.node)) {  // Node feeds merge.
54       return false;
55     }
56   }
57   return true;
58 }
59 
IsTrivialOp(const NodeDef & node,const GraphView & graph_view)60 bool IsTrivialOp(const NodeDef& node, const GraphView& graph_view) {
61   // Remove the stop gradient nodes since they serve no purpose once the graph
62   // is built. Also remove Identity ops.
63   if (IsStopGradient(node)) {
64     return true;
65   }
66   if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
67     return IsTrivialIdentity(node, graph_view);
68   }
69   if (IsNoOp(node) && node.input().empty()) {
70     return true;
71   }
72   // Const nodes are always executed before anything else, so if they only
73   // have control outputs we can remove them.
74   if (IsConstant(node) && node.input().empty() &&
75       graph_view.NumFanouts(node, /*include_controlled_nodes=*/false) == 0) {
76     return true;
77   }
78   return IsAddN(node) && NumNonControlInputs(node) <= 1;
79 }
80 
RemovalIncreasesEdgeCount(const NodeDef & node,const GraphView & graph_view)81 bool RemovalIncreasesEdgeCount(const NodeDef& node,
82                                const GraphView& graph_view) {
83   int in_degree =
84       graph_view.NumFanins(node, /*include_controlling_nodes=*/true);
85   int out_degree =
86       graph_view.NumFanouts(node, /*include_controlled_nodes=*/true);
87   return in_degree * out_degree > in_degree + out_degree;
88 }
89 
IsOutputPortRefValue(const NodeDef & node,int port_id,const OpRegistryInterface & op_registry)90 bool IsOutputPortRefValue(const NodeDef& node, int port_id,
91                           const OpRegistryInterface& op_registry) {
92   const OpRegistrationData* op_reg_data = nullptr;
93   Status s = op_registry.LookUp(node.op(), &op_reg_data);
94   if (s.ok()) {
95     DataType output_type;
96     s = OutputTypeForNode(node, op_reg_data->op_def, port_id, &output_type);
97     if (s.ok() && IsRefType(output_type)) {
98       return true;
99     }
100   }
101   return false;
102 }
103 
CanRemoveNode(const NodeDef & node,const GraphView & graph_view,const absl::flat_hash_set<string> & function_names,const OpRegistryInterface & op_registry)104 bool CanRemoveNode(const NodeDef& node, const GraphView& graph_view,
105                    const absl::flat_hash_set<string>& function_names,
106                    const OpRegistryInterface& op_registry) {
107   if (IsNoOp(node) &&
108       (node.input().empty() ||
109        graph_view.NumFanouts(node, /*include_controlled_nodes=*/true) == 0)) {
110     return true;
111   }
112   if (IsConstant(node) && node.input().empty() &&
113       graph_view.NumFanouts(node, /*include_controlled_nodes=*/false) == 0) {
114     return true;
115   }
116   if (RemovalIncreasesEdgeCount(node, graph_view)) {
117     return false;
118   }
119   for (const auto input :
120        graph_view.GetFanins(node, /*include_controlling_nodes=*/true)) {
121     if (node.device() != input.node->device()) {
122       // Node is driven by a different device.
123       return false;
124     } else if (input.port_id == Graph::kControlSlot) {
125       // Node is driven by control dependency.
126       continue;
127     } else if (function_names.find(input.node->op()) != function_names.end()) {
128       // Node input is a function call.
129       return false;
130     } else if (IsOutputPortRefValue(*input.node, input.port_id, op_registry)) {
131       return false;
132     }
133   }
134   for (const auto output :
135        graph_view.GetFanouts(node, /*include_controlled_nodes=*/false)) {
136     if (function_names.find(output.node->op()) != function_names.end()) {
137       // Node output is a function call.
138       return false;
139     }
140   }
141   return true;
142 }
143 
ForwardInputsInternal(const NodeDef & node,const absl::flat_hash_set<const NodeDef * > & nodes_to_delete,bool add_as_control,NodeDef * new_node,const absl::flat_hash_map<string,const NodeDef * > & optimized_nodes,const GraphView & graph_view)144 void ForwardInputsInternal(
145     const NodeDef& node,
146     const absl::flat_hash_set<const NodeDef*>& nodes_to_delete,
147     bool add_as_control, NodeDef* new_node,
148     const absl::flat_hash_map<string, const NodeDef*>& optimized_nodes,
149     const GraphView& graph_view) {
150   // To speed things up, use the optimized version of the node if
151   // available.
152   auto itr = optimized_nodes.find(node.name());
153   if (itr != optimized_nodes.end()) {
154     for (const string& input : itr->second->input()) {
155       *new_node->add_input() =
156           add_as_control ? AsControlDependency(NodeName(input)) : input;
157     }
158     return;
159   }
160   for (const auto& input : node.input()) {
161     const NodeDef* input_node = graph_view.GetNode(NodeName(input));
162     if (input_node == nullptr) {
163       // Invalid input, preserve it as is.
164       *new_node->add_input() =
165           add_as_control ? AsControlDependency(NodeName(input)) : input;
166       continue;
167     }
168     if (nodes_to_delete.find(input_node) != nodes_to_delete.end()) {
169       ForwardInputsInternal(*input_node, nodes_to_delete,
170                             add_as_control || IsControlInput(input), new_node,
171                             optimized_nodes, graph_view);
172     } else {
173       *new_node->add_input() =
174           add_as_control ? AsControlDependency(NodeName(input)) : input;
175     }
176   }
177 }
178 
ForwardInputs(const NodeDef & original_node,const absl::flat_hash_set<const NodeDef * > & nodes_to_delete,NodeDef * new_node,absl::flat_hash_map<string,const NodeDef * > * optimized_nodes,const GraphView & graph_view)179 void ForwardInputs(const NodeDef& original_node,
180                    const absl::flat_hash_set<const NodeDef*>& nodes_to_delete,
181                    NodeDef* new_node,
182                    absl::flat_hash_map<string, const NodeDef*>* optimized_nodes,
183                    const GraphView& graph_view) {
184   // Forwards inputs of nodes to be deleted to their respective outputs.
185   ForwardInputsInternal(original_node, nodes_to_delete,
186                         /*add_as_control=*/false, new_node, *optimized_nodes,
187                         graph_view);
188   if (!new_node->name().empty()) {
189     (*optimized_nodes)[new_node->name()] = new_node;
190   }
191   // Reorder inputs such that control inputs come after regular inputs.
192   int pos = 0;
193   for (int i = 0; i < new_node->input_size(); ++i) {
194     if (!IsControlInput(new_node->input(i))) {
195       new_node->mutable_input()->SwapElements(pos, i);
196       ++pos;
197     }
198   }
199   DedupControlInputs(new_node);
200 }
201 
IdentityNTerminalPorts(const NodeMap & node_map,const std::vector<string> & terminal_nodes,int graph_size)202 absl::flat_hash_map<string, absl::flat_hash_set<int>> IdentityNTerminalPorts(
203     const NodeMap& node_map, const std::vector<string>& terminal_nodes,
204     int graph_size) {
205   // Determines which ports for IdentityN nodes (that can be rewritten) lead to
206   // a terminal node.
207   std::vector<string> to_visit;
208   to_visit.reserve(graph_size);
209   // Set terminal nodes as visited so terminal nodes that may be IdentityN don't
210   // get pruned later on.
211   absl::flat_hash_set<string> visited(terminal_nodes.begin(),
212                                       terminal_nodes.end());
213   for (const string& terminal_node : terminal_nodes) {
214     NodeDef* node = node_map.GetNode(terminal_node);
215     if (node == nullptr) {
216       continue;
217     }
218     for (const string& input : node->input()) {
219       to_visit.push_back(input);
220     }
221   }
222 
223   absl::flat_hash_set<string> identity_n_fanouts;
224   while (!to_visit.empty()) {
225     string curr = to_visit.back();
226     to_visit.pop_back();
227     NodeDef* curr_node = node_map.GetNode(curr);
228     if (curr_node == nullptr ||
229         visited.find(curr_node->name()) != visited.end()) {
230       continue;
231     }
232     // For IdentityN nodes, only traverse up through the port that comes from a
233     // terminal node along with control inputs. The IdentityN node is not marked
234     // as visited so other node input traversals can go through the other ports
235     // of the IdentityN node.
236     if (IsIdentityN(*curr_node)) {
237       if (identity_n_fanouts.find(curr) == identity_n_fanouts.end()) {
238         identity_n_fanouts.emplace(curr);
239         int pos = NodePositionIfSameNode(curr, curr_node->name());
240         if (pos >= 0) {
241           to_visit.push_back(curr_node->input(pos));
242         }
243         for (const string& input : curr_node->input()) {
244           if (IsControlInput(input) &&
245               identity_n_fanouts.find(input) == identity_n_fanouts.end()) {
246             to_visit.push_back(input);
247           }
248         }
249       }
250     } else {
251       for (const string& input : curr_node->input()) {
252         to_visit.push_back(input);
253       }
254       visited.emplace(curr_node->name());
255     }
256   }
257 
258   absl::flat_hash_map<string, absl::flat_hash_set<int>> identity_n_ports;
259   for (const auto& fanout : identity_n_fanouts) {
260     int pos;
261     string node_name = ParseNodeName(fanout, &pos);
262     if (node_name.empty() || pos < 0) {  // Exclude control inputs.
263       continue;
264     }
265     if (identity_n_ports.find(node_name) == identity_n_ports.end()) {
266       identity_n_ports[node_name] = {pos};
267     } else {
268       identity_n_ports[node_name].emplace(pos);
269     }
270   }
271 
272   return identity_n_ports;
273 }
274 
NewIdentityFromIdentityN(int pos,const NodeDef & identity_n,GraphDef * graph,NodeMap * node_map)275 string NewIdentityFromIdentityN(int pos, const NodeDef& identity_n,
276                                 GraphDef* graph, NodeMap* node_map) {
277   // TODO(lyandy): Migrate over to GrapplerOptimizerStage and use
278   // OptimizedNodeName for new node name.
279   string new_node_name =
280       strings::StrCat(identity_n.name(), "-", pos, "-grappler-ModelPruner");
281   if (node_map->NodeExists(new_node_name)) {
282     return "";
283   }
284   NodeDef* new_node = graph->add_node();
285   Status status = NodeDefBuilder(new_node_name, "Identity")
286                       .Input(identity_n.input(pos), 0,
287                              identity_n.attr().at("T").list().type(pos))
288                       .Device(identity_n.device())
289                       .Finalize(new_node);
290   if (!status.ok()) {
291     return "";
292   }
293   node_map->AddNode(new_node->name(), new_node);
294   node_map->AddOutput(NodeName(new_node->input(0)), new_node->name());
295   return new_node->name();
296 }
297 
RewriteIdentityNAndInputsOutputs(NodeDef * node,int num_non_control_inputs,const absl::flat_hash_set<int> & terminal_ports,GraphDef * graph,NodeMap * node_map)298 Status RewriteIdentityNAndInputsOutputs(
299     NodeDef* node, int num_non_control_inputs,
300     const absl::flat_hash_set<int>& terminal_ports, GraphDef* graph,
301     NodeMap* node_map) {
302   // Rewrite IdentityN node and associated inputs and outputs. For inputs and
303   // outputs that don't lead to a terminal node, a new Identity node is created
304   // and those inputs and outputs are rewritten to use the new Identity node as
305   // their outputs and inputs respectively. For the remaining nodes, the outputs
306   // have their inputs updated with the adjusted port, from the IdentityN node
307   // having less inputs.
308   struct NodeOutputUpdate {
309     string input;
310     string output;
311   };
312 
313   absl::flat_hash_map<int, int> terminal_input_pos;
314   absl::flat_hash_map<int, string> new_identities;
315   int new_idx = 0;
316   for (int i = 0; i < num_non_control_inputs; i++) {
317     if (terminal_ports.find(i) != terminal_ports.end()) {
318       terminal_input_pos[i] = new_idx++;
319     } else {
320       string identity = NewIdentityFromIdentityN(i, *node, graph, node_map);
321       if (identity.empty()) {
322         // Fail early when creating Identity from IdentityN errors.
323         return errors::Internal(
324             "Could not create Identity node from IdentityN node ", node->name(),
325             " at port ", i);
326       }
327       new_identities[i] = identity;
328     }
329   }
330 
331   std::vector<NodeOutputUpdate> updates;
332   for (NodeDef* output : node_map->GetOutputs(node->name())) {
333     for (int i = 0; i < output->input_size(); i++) {
334       string input = output->input(i);
335       if (IsControlInput(input)) {
336         continue;
337       }
338       TensorId input_tensor = ParseTensorName(input);
339       if (input_tensor.node() == node->name()) {
340         if (terminal_ports.find(input_tensor.index()) == terminal_ports.end()) {
341           // Replace input that does not lead to a terminal node with newly
342           // created identity.
343           string new_identity = new_identities[input_tensor.index()];
344           output->set_input(i, new_identity);
345           updates.push_back({new_identity, output->name()});
346         } else {
347           // Update input ports that lead to a terminal node from splitting
348           // inputs.
349           int new_pos = terminal_input_pos[input_tensor.index()];
350           string updated_input_name =
351               new_pos > 0 ? strings::StrCat(node->name(), ":", new_pos)
352                           : node->name();
353           output->set_input(i, updated_input_name);
354         }
355       }
356     }
357   }
358 
359   for (const NodeOutputUpdate& update : updates) {
360     node_map->AddOutput(update.input, update.output);
361   }
362 
363   // Update inputs and types by removing inputs that were split away from
364   // main IdentityN node.
365   const int num_inputs = node->input_size();
366   int curr_pos = 0;
367   auto mutable_inputs = node->mutable_input();
368   auto mutable_types =
369       node->mutable_attr()->at("T").mutable_list()->mutable_type();
370   for (int i = 0; i < num_non_control_inputs; i++) {
371     if (terminal_input_pos.find(i) != terminal_input_pos.end()) {
372       mutable_inputs->SwapElements(i, curr_pos);
373       mutable_types->SwapElements(i, curr_pos);
374       curr_pos++;
375     }
376   }
377   mutable_types->Truncate(curr_pos);
378   // Control inputs.
379   for (int i = num_non_control_inputs; i < num_inputs; i++) {
380     mutable_inputs->SwapElements(i, curr_pos++);
381   }
382   mutable_inputs->DeleteSubrange(curr_pos, num_inputs - curr_pos);
383 
384   return Status::OK();
385 }
386 
SplitIdentityNInputs(GraphDef * graph,const std::vector<string> & terminal_nodes,bool * updated_graph)387 Status SplitIdentityNInputs(GraphDef* graph,
388                             const std::vector<string>& terminal_nodes,
389                             bool* updated_graph) {
390   // For inputs of IdentityN nodes that do not lead to a terminal node, remove
391   // them from IdentityN and create new individual Identity nodes. This will
392   // allow ModelPruner to possibly remove nodes in the transitive fanin of the
393   // newly created Identity nodes.
394   NodeMap node_map(graph);
395 
396   for (auto const& terminal :
397        IdentityNTerminalPorts(node_map, terminal_nodes, graph->node_size())) {
398     NodeDef* node = node_map.GetNode(terminal.first);
399     if (node == nullptr) {
400       continue;
401     }
402 
403     const int num_non_control_inputs = NumNonControlInputs(*node);
404     const int terminal_second_size = terminal.second.size();
405     if (node->attr().count("T") == 0 ||
406         node->attr().at("T").list().type_size() != num_non_control_inputs ||
407         terminal_second_size >= num_non_control_inputs) {
408       continue;
409     }
410 
411     TF_RETURN_IF_ERROR(RewriteIdentityNAndInputsOutputs(
412         node, num_non_control_inputs, terminal.second, graph, &node_map));
413     *updated_graph = true;
414   }
415 
416   return Status::OK();
417 }
418 
419 }  // namespace
420 
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)421 Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
422                              GraphDef* optimized_graph) {
423   const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve();
424 
425   // Prune all the nodes that won't be executed, ie all the nodes that aren't in
426   // the fanin of a fetch node. If fetch nodes aren't specified, we'll assume
427   // the whole graph might be executed.
428   std::unique_ptr<GraphDef> pruned_graph_release;
429   GraphDef* pruned_graph;
430   if (!nodes_to_preserve.empty()) {
431     pruned_graph_release.reset(new GraphDef());
432     pruned_graph = pruned_graph_release.get();
433     pruned_graph->mutable_node()->Reserve(item.graph.node_size());
434     std::vector<string> terminal_nodes(nodes_to_preserve.begin(),
435                                        nodes_to_preserve.end());
436     std::sort(terminal_nodes.begin(), terminal_nodes.end());
437     TF_RETURN_IF_ERROR(
438         SetTransitiveFaninGraph(item.graph, pruned_graph, terminal_nodes));
439     bool did_split_identity_n = false;
440     TF_RETURN_IF_ERROR(SplitIdentityNInputs(pruned_graph, terminal_nodes,
441                                             &did_split_identity_n));
442     if (did_split_identity_n) {
443       GraphDef fanin_split_identity_n_graph;
444       TF_RETURN_IF_ERROR(SetTransitiveFaninGraph(
445           *pruned_graph, &fanin_split_identity_n_graph, terminal_nodes));
446       pruned_graph->Swap(&fanin_split_identity_n_graph);
447     }
448     GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
449   } else {
450     pruned_graph = const_cast<GraphDef*>(&item.graph);
451   }
452 
453   GraphView graph_view(pruned_graph);
454   absl::flat_hash_set<string> function_names;
455   for (const auto& function : item.graph.library().function()) {
456     function_names.insert(function.signature().name());
457   }
458   OpRegistryInterface* op_registry = OpRegistry::Global();
459 
460   // Check if we can further prune the graph, by removing the trivial ops.
461   absl::flat_hash_set<const NodeDef*> nodes_to_delete;
462   for (int i = 0; i < pruned_graph->node_size(); ++i) {
463     NodeDef* node = pruned_graph->mutable_node(i);
464     // Remove redundant control inputs, since they may prevent pruning below.
465     DedupControlInputs(node);
466 
467     if (!IsTrivialOp(*node, graph_view)) {
468       VLOG(3) << node->name() << " is not trivial.";
469       continue;
470     }
471 
472     // Don't remove nodes that must be preserved.
473     if (nodes_to_preserve.find(node->name()) != nodes_to_preserve.end()) {
474       continue;
475     }
476 
477     // - Don't remove nodes that drive control dependencies.
478     // - Don't remove nodes that are driven by control dependencies either since
479     //   we can't ensure (yet) that we won't increase the number of control
480     //   dependency edges by deleting them (for example, removing a node driven
481     //   by 10 control edges and driving 10 control edges would result in the
482     //   creation of 100 edges).
483     // - Don't modify nodes that are connected to functions since that can
484     //   result in inlining failures later on.
485     // - Don't prune nodes that are driven by another device since these could
486     //   be used to reduce cross device communication.
487     // - Don't remove nodes that receive reference values, as those can be
488     //   converting references to non-references. It is important to preserve
489     //   these non-references since the partitioner will avoid sending
490     //   non-references across partitions more than once.
491     if (CanRemoveNode(*node, graph_view, function_names, *op_registry)) {
492       nodes_to_delete.insert(node);
493     } else {
494       VLOG(3) << node->name() << " cannot be removed";
495     }
496   }
497 
498   if (nodes_to_delete.empty() && nodes_to_preserve.empty()) {
499     return errors::Aborted("Nothing to do.");
500   }
501 
502   optimized_graph->Clear();
503   *optimized_graph->mutable_library() = item.graph.library();
504   *optimized_graph->mutable_versions() = item.graph.versions();
505   if (nodes_to_delete.empty()) {
506     optimized_graph->mutable_node()->Swap(pruned_graph->mutable_node());
507     return Status::OK();
508   }
509 
510   const bool fetches_are_known = !item.fetch.empty();
511   absl::flat_hash_map<string, const NodeDef*> optimized_nodes;
512   optimized_graph->mutable_node()->Reserve(pruned_graph->node_size());
513   for (const auto& node : pruned_graph->node()) {
514     if (!fetches_are_known ||
515         nodes_to_delete.find(&node) == nodes_to_delete.end()) {
516       NodeDef* new_node = optimized_graph->add_node();
517       *new_node = node;
518       new_node->clear_input();
519       ForwardInputs(node, nodes_to_delete, new_node, &optimized_nodes,
520                     graph_view);
521     }
522   }
523   VLOG(1) << "Pruned " << nodes_to_delete.size()
524           << " nodes from the graph. The graph now contains "
525           << optimized_graph->node_size() << " nodes.";
526   if (optimized_graph->node_size() > item.graph.node_size()) {
527     return errors::Internal("Pruning increased graph size.");
528   }
529   return Status::OK();
530 }
531 
Feedback(Cluster * cluster,const GrapplerItem & item,const GraphDef & optimized_graph,double result)532 void ModelPruner::Feedback(Cluster* cluster, const GrapplerItem& item,
533                            const GraphDef& optimized_graph, double result) {
534   // Nothing to do for ModelPruner.
535 }
536 
537 }  // end namespace grappler
538 }  // end namespace tensorflow
539