1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
17 
18 #include <unordered_set>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/function.pb.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/node_def_builder.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/framework/versions.pb.h"
28 #include "tensorflow/core/grappler/grappler_item.h"
29 #include "tensorflow/core/grappler/mutable_graph_view.h"
30 #include "tensorflow/core/grappler/op_types.h"
31 #include "tensorflow/core/grappler/utils.h"
32 
33 namespace tensorflow {
34 namespace grappler {
35 
IsTrivialIdentity(const NodeDef & node,const MutableGraphView & graph_view)36 bool IsTrivialIdentity(const NodeDef& node,
37                        const MutableGraphView& graph_view) {
38   for (const auto input :
39        graph_view.GetFanins(node, /*include_controlling_nodes=*/true)) {
40     if (input.port_id == Graph::kControlSlot) {
41       // Node is driven by control dependency.
42       return false;
43     } else if (IsSwitch(*input.node)) {  // Node is driven by switch.
44       return false;
45     }
46   }
47   for (const auto output :
48        graph_view.GetFanouts(node, /*include_controlled_nodes=*/true)) {
49     if (output.port_id == Graph::kControlSlot) {
50       // Node drives control dependency.
51       return false;
52     } else if (IsMerge(*output.node)) {  // Node feeds merge.
53       return false;
54     }
55   }
56   return true;
57 }
58 
IsTrivialOp(const NodeDef & node,const MutableGraphView & graph_view)59 bool IsTrivialOp(const NodeDef& node, const MutableGraphView& graph_view) {
60   // Remove the stop gradient nodes since they serve no purpose once the graph
61   // is built. Also remove Identity ops.
62   if (IsStopGradient(node)) {
63     return true;
64   }
65   if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
66     return IsTrivialIdentity(node, graph_view);
67   }
68 
69   return IsAddN(node) && NumNonControlInputs(node) <= 1;
70 }
71 
RemovalIncreasesEdgeCount(const NodeDef & node,const MutableGraphView & graph_view)72 bool RemovalIncreasesEdgeCount(const NodeDef& node,
73                                const MutableGraphView& graph_view) {
74   int in_degree =
75       graph_view.NumFanins(node, /*include_controlling_nodes=*/true);
76   int out_degree =
77       graph_view.NumFanouts(node, /*include_controlling_nodes=*/true);
78   return in_degree * out_degree > in_degree + out_degree;
79 }
80 
IsOutputPortRefValue(const NodeDef & node,int port_id,const OpRegistryInterface & op_registry)81 bool IsOutputPortRefValue(const NodeDef& node, int port_id,
82                           const OpRegistryInterface& op_registry) {
83   const OpRegistrationData* op_reg_data = nullptr;
84   Status s = op_registry.LookUp(node.op(), &op_reg_data);
85   if (s.ok()) {
86     DataType output_type;
87     s = OutputTypeForNode(node, op_reg_data->op_def, port_id, &output_type);
88     if (s.ok() && IsRefType(output_type)) {
89       return true;
90     }
91   }
92   return false;
93 }
94 
CanRemoveNode(const NodeDef & node,const MutableGraphView & graph_view,const absl::flat_hash_set<string> & function_names,const OpRegistryInterface & op_registry)95 bool CanRemoveNode(const NodeDef& node, const MutableGraphView& graph_view,
96                    const absl::flat_hash_set<string>& function_names,
97                    const OpRegistryInterface& op_registry) {
98   if (RemovalIncreasesEdgeCount(node, graph_view)) {
99     return false;
100   }
101   for (const auto input :
102        graph_view.GetFanins(node, /*include_controlling_nodes=*/true)) {
103     if (node.device() != input.node->device()) {
104       // Node is driven by a different device.
105       return false;
106     } else if (input.port_id == Graph::kControlSlot) {
107       // Node is driven by control dependency.
108       continue;
109     } else if (function_names.find(input.node->op()) != function_names.end()) {
110       // Node input is a function call.
111       return false;
112     } else if (IsOutputPortRefValue(*input.node, input.port_id, op_registry)) {
113       return false;
114     }
115   }
116   for (const auto output :
117        graph_view.GetFanouts(node, /*include_controlled_nodes=*/false)) {
118     if (function_names.find(output.node->op()) != function_names.end()) {
119       // Node output is a function call.
120       return false;
121     }
122   }
123   return true;
124 }
125 
ForwardInputsInternal(const NodeDef & node,const absl::flat_hash_set<const NodeDef * > & nodes_to_delete,bool add_as_control,NodeDef * new_node,const absl::flat_hash_map<string,const NodeDef * > & optimized_nodes,const MutableGraphView & graph_view)126 void ForwardInputsInternal(
127     const NodeDef& node,
128     const absl::flat_hash_set<const NodeDef*>& nodes_to_delete,
129     bool add_as_control, NodeDef* new_node,
130     const absl::flat_hash_map<string, const NodeDef*>& optimized_nodes,
131     const MutableGraphView& graph_view) {
132   // To speed things up, use the optimized version of the node if
133   // available.
134   auto itr = optimized_nodes.find(node.name());
135   if (itr != optimized_nodes.end()) {
136     for (const string& input : itr->second->input()) {
137       *new_node->add_input() =
138           add_as_control ? AsControlDependency(NodeName(input)) : input;
139     }
140     return;
141   }
142   for (const auto& input : node.input()) {
143     const NodeDef* input_node = graph_view.GetNode(NodeName(input));
144     if (input_node == nullptr) {
145       // Invalid input, preserve it as is.
146       *new_node->add_input() =
147           add_as_control ? AsControlDependency(NodeName(input)) : input;
148       continue;
149     }
150     if (nodes_to_delete.find(input_node) != nodes_to_delete.end()) {
151       ForwardInputsInternal(*input_node, nodes_to_delete,
152                             add_as_control || IsControlInput(input), new_node,
153                             optimized_nodes, graph_view);
154     } else {
155       *new_node->add_input() =
156           add_as_control ? AsControlDependency(NodeName(input)) : input;
157     }
158   }
159 }
160 
ForwardInputs(const NodeDef & original_node,const absl::flat_hash_set<const NodeDef * > & nodes_to_delete,NodeDef * new_node,absl::flat_hash_map<string,const NodeDef * > * optimized_nodes,const MutableGraphView & graph_view)161 void ForwardInputs(const NodeDef& original_node,
162                    const absl::flat_hash_set<const NodeDef*>& nodes_to_delete,
163                    NodeDef* new_node,
164                    absl::flat_hash_map<string, const NodeDef*>* optimized_nodes,
165                    const MutableGraphView& graph_view) {
166   // Forwards inputs of nodes to be deleted to their respective outputs.
167   ForwardInputsInternal(original_node, nodes_to_delete,
168                         /*add_as_control=*/false, new_node, *optimized_nodes,
169                         graph_view);
170   if (!new_node->name().empty()) {
171     (*optimized_nodes)[new_node->name()] = new_node;
172   }
173   // Reorder inputs such that control inputs come after regular inputs.
174   int pos = 0;
175   for (int i = 0; i < new_node->input_size(); ++i) {
176     if (!IsControlInput(new_node->input(i))) {
177       new_node->mutable_input()->SwapElements(pos, i);
178       ++pos;
179     }
180   }
181   DedupControlInputs(new_node);
182 }
183 
IdentityNTerminalPorts(const NodeMap & node_map,const std::vector<string> & terminal_nodes,int graph_size)184 absl::flat_hash_map<string, absl::flat_hash_set<int>> IdentityNTerminalPorts(
185     const NodeMap& node_map, const std::vector<string>& terminal_nodes,
186     int graph_size) {
187   // Determines which ports for IdentityN nodes (that can be rewritten) lead to
188   // a terminal node.
189   std::vector<string> to_visit;
190   to_visit.reserve(graph_size);
191   // Set terminal nodes as visited so terminal nodes that may be IdentityN don't
192   // get pruned later on.
193   absl::flat_hash_set<string> visited(terminal_nodes.begin(),
194                                       terminal_nodes.end());
195   for (string terminal_node : terminal_nodes) {
196     NodeDef* node = node_map.GetNode(terminal_node);
197     if (node == nullptr) {
198       continue;
199     }
200     for (string input : node->input()) {
201       to_visit.push_back(input);
202     }
203   }
204 
205   absl::flat_hash_set<string> identity_n_fanouts;
206   while (!to_visit.empty()) {
207     string curr = to_visit.back();
208     to_visit.pop_back();
209     NodeDef* curr_node = node_map.GetNode(curr);
210     if (curr_node == nullptr ||
211         visited.find(curr_node->name()) != visited.end()) {
212       continue;
213     }
214     // For IdentityN nodes, only traverse up through the port that comes from a
215     // terminal node along with control inputs. The IdentityN node is not marked
216     // as visited so other node input traversals can go through the other ports
217     // of the IdentityN node.
218     if (IsIdentityN(*curr_node)) {
219       if (identity_n_fanouts.find(curr) == identity_n_fanouts.end()) {
220         identity_n_fanouts.emplace(curr);
221         int pos = NodePositionIfSameNode(curr, curr_node->name());
222         if (pos >= 0) {
223           to_visit.push_back(curr_node->input(pos));
224         }
225         for (const string& input : curr_node->input()) {
226           if (IsControlInput(input) &&
227               identity_n_fanouts.find(input) == identity_n_fanouts.end()) {
228             to_visit.push_back(input);
229           }
230         }
231       }
232     } else {
233       for (const string& input : curr_node->input()) {
234         to_visit.push_back(input);
235       }
236       visited.emplace(curr_node->name());
237     }
238   }
239 
240   absl::flat_hash_map<string, absl::flat_hash_set<int>> identity_n_ports;
241   for (const auto& fanout : identity_n_fanouts) {
242     int pos;
243     string node_name = ParseNodeName(fanout, &pos);
244     if (node_name.empty() || pos < 0) {  // Exclude control inputs.
245       continue;
246     }
247     if (identity_n_ports.find(node_name) == identity_n_ports.end()) {
248       identity_n_ports[node_name] = {pos};
249     } else {
250       identity_n_ports[node_name].emplace(pos);
251     }
252   }
253 
254   return identity_n_ports;
255 }
256 
NewIdentityFromIdentityN(int pos,const NodeDef & identity_n,GraphDef * graph,NodeMap * node_map)257 string NewIdentityFromIdentityN(int pos, const NodeDef& identity_n,
258                                 GraphDef* graph, NodeMap* node_map) {
259   // TODO(lyandy): Migrate over to GrapplerOptimizerStage and use
260   // OptimizedNodeName for new node name.
261   string new_node_name =
262       strings::StrCat(identity_n.name(), "-", pos, "-grappler-ModelPruner");
263   if (node_map->NodeExists(new_node_name)) {
264     return "";
265   }
266   NodeDef* new_node = graph->add_node();
267   Status status = NodeDefBuilder(new_node_name, "Identity")
268                       .Input(identity_n.input(pos), 0,
269                              identity_n.attr().at("T").list().type(pos))
270                       .Device(identity_n.device())
271                       .Finalize(new_node);
272   if (!status.ok()) {
273     return "";
274   }
275   node_map->AddNode(new_node->name(), new_node);
276   node_map->AddOutput(NodeName(new_node->input(0)), new_node->name());
277   return new_node->name();
278 }
279 
RewriteIdentityNAndInputsOutputs(NodeDef * node,int num_non_control_inputs,const absl::flat_hash_set<int> & terminal_ports,GraphDef * graph,NodeMap * node_map)280 Status RewriteIdentityNAndInputsOutputs(
281     NodeDef* node, int num_non_control_inputs,
282     const absl::flat_hash_set<int>& terminal_ports, GraphDef* graph,
283     NodeMap* node_map) {
284   // Rewrite IdentityN node and associated inputs and outputs. For inputs and
285   // outputs that don't lead to a terminal node, a new Identity node is created
286   // and those inputs and outputs are rewritten to use the new Identity node as
287   // their outputs and inputs respectively. For the remaining nodes, the ouputs
288   // have their inputs updated with the adjusted port, from the IdentityN node
289   // having less inputs.
290   struct NodeOutputUpdate {
291     string input;
292     string output;
293   };
294 
295   absl::flat_hash_map<int, int> terminal_input_pos;
296   absl::flat_hash_map<int, string> new_identities;
297   int new_idx = 0;
298   for (int i = 0; i < num_non_control_inputs; i++) {
299     if (terminal_ports.find(i) != terminal_ports.end()) {
300       terminal_input_pos[i] = new_idx++;
301     } else {
302       string identity = NewIdentityFromIdentityN(i, *node, graph, node_map);
303       if (identity.empty()) {
304         // Fail early when creating Identity from IdentityN errors.
305         return errors::Internal(
306             "Could not create Identity node from IdentityN node ", node->name(),
307             " at port ", i);
308       }
309       new_identities[i] = identity;
310     }
311   }
312 
313   std::vector<NodeOutputUpdate> updates;
314   for (NodeDef* output : node_map->GetOutputs(node->name())) {
315     for (int i = 0; i < output->input_size(); i++) {
316       string input = output->input(i);
317       if (IsControlInput(input)) {
318         continue;
319       }
320       TensorId input_tensor = ParseTensorName(input);
321       if (input_tensor.node() == node->name()) {
322         if (terminal_ports.find(input_tensor.index()) == terminal_ports.end()) {
323           // Replace input that does not lead to a terminal node with newly
324           // created identity.
325           string new_identity = new_identities[input_tensor.index()];
326           output->set_input(i, new_identity);
327           updates.push_back({new_identity, output->name()});
328         } else {
329           // Update input ports that lead to a terminal node from splitting
330           // inputs.
331           int new_pos = terminal_input_pos[input_tensor.index()];
332           string updated_input_name =
333               new_pos > 0 ? strings::StrCat(node->name(), ":", new_pos)
334                           : node->name();
335           output->set_input(i, updated_input_name);
336         }
337       }
338     }
339   }
340 
341   for (NodeOutputUpdate update : updates) {
342     node_map->AddOutput(update.input, update.output);
343   }
344 
345   // Update inputs and types by removing inputs that were split away from
346   // main IdentityN node.
347   const int num_inputs = node->input_size();
348   int curr_pos = 0;
349   auto mutable_inputs = node->mutable_input();
350   auto mutable_types =
351       node->mutable_attr()->at("T").mutable_list()->mutable_type();
352   for (int i = 0; i < num_non_control_inputs; i++) {
353     if (terminal_input_pos.find(i) != terminal_input_pos.end()) {
354       mutable_inputs->SwapElements(i, curr_pos);
355       mutable_types->SwapElements(i, curr_pos);
356       curr_pos++;
357     }
358   }
359   mutable_types->Truncate(curr_pos);
360   // Control inputs.
361   for (int i = num_non_control_inputs; i < num_inputs; i++) {
362     mutable_inputs->SwapElements(i, curr_pos++);
363   }
364   mutable_inputs->DeleteSubrange(curr_pos, num_inputs - curr_pos);
365 
366   return Status::OK();
367 }
368 
SplitIdentityNInputs(GraphDef * graph,const std::vector<string> & terminal_nodes,bool * updated_graph)369 Status SplitIdentityNInputs(GraphDef* graph,
370                             const std::vector<string>& terminal_nodes,
371                             bool* updated_graph) {
372   // For inputs of IdentityN nodes that do not lead to a terminal node, remove
373   // them from IdentityN and create new individual Identity nodes. This will
374   // allow ModelPruner to possibly remove nodes in the transitive fanin of the
375   // newly created Identity nodes.
376   NodeMap node_map(graph);
377 
378   for (auto const& terminal :
379        IdentityNTerminalPorts(node_map, terminal_nodes, graph->node_size())) {
380     NodeDef* node = node_map.GetNode(terminal.first);
381     if (node == nullptr) {
382       continue;
383     }
384 
385     const int num_non_control_inputs = NumNonControlInputs(*node);
386     if (node->attr().count("T") == 0 ||
387         node->attr().at("T").list().type_size() != num_non_control_inputs ||
388         terminal.second.size() >= num_non_control_inputs) {
389       continue;
390     }
391 
392     TF_RETURN_IF_ERROR(RewriteIdentityNAndInputsOutputs(
393         node, num_non_control_inputs, terminal.second, graph, &node_map));
394     *updated_graph = true;
395   }
396 
397   return Status::OK();
398 }
399 
SetTransitiveFaninGraph(const GraphDef & input_graph,GraphDef * output_graph,const std::vector<string> & terminal_nodes)400 Status SetTransitiveFaninGraph(const GraphDef& input_graph,
401                                GraphDef* output_graph,
402                                const std::vector<string>& terminal_nodes) {
403   // Determines transitive fanin nodes from terminal nodes and add them to the
404   // output graph.
405   bool ill_formed = false;
406   std::vector<const NodeDef*> keep =
407       ComputeTransitiveFanin(input_graph, terminal_nodes, &ill_formed);
408   if (ill_formed) {
409     // Some graph edges are invalid, or some of the feeds/fetch don't exist:
410     // let's be conservative and preserve the graph as is.
411     return errors::InvalidArgument("Invalid input graph.");
412   }
413   // Try to keep the nodes ordered somewhat topologically since this helps
414   // further optimizations perform better.
415   output_graph->mutable_node()->Reserve(keep.size());
416   for (int i = keep.size() - 1; i >= 0; --i) {
417     *output_graph->add_node() = *keep[i];
418   }
419 
420   return Status::OK();
421 }
422 
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * pruned_graph)423 Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
424                              GraphDef* pruned_graph) {
425   const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve();
426 
427   // Prune all the nodes that won't be executed, ie all the nodes that aren't in
428   // the fanin of a fetch node. If fetch nodes aren't specified, we'll assume
429   // the whole graph might be executed.
430   GrapplerItem runnable_item;
431   if (!nodes_to_preserve.empty()) {
432     std::vector<string> terminal_nodes(nodes_to_preserve.begin(),
433                                        nodes_to_preserve.end());
434     std::sort(terminal_nodes.begin(), terminal_nodes.end());
435     TF_RETURN_IF_ERROR(SetTransitiveFaninGraph(item.graph, &runnable_item.graph,
436                                                terminal_nodes));
437     bool did_split_identity_n = false;
438     TF_RETURN_IF_ERROR(SplitIdentityNInputs(
439         &runnable_item.graph, terminal_nodes, &did_split_identity_n));
440     if (did_split_identity_n) {
441       GraphDef fanin_split_identity_n_graph;
442       TF_RETURN_IF_ERROR(SetTransitiveFaninGraph(
443           runnable_item.graph, &fanin_split_identity_n_graph, terminal_nodes));
444       runnable_item.graph.Swap(&fanin_split_identity_n_graph);
445     }
446   } else {
447     runnable_item = item;
448   }
449 
450   MutableGraphView graph_view(&runnable_item.graph);
451   absl::flat_hash_set<string> function_names;
452   for (const auto& function : item.graph.library().function()) {
453     function_names.insert(function.signature().name());
454   }
455   OpRegistryInterface* op_registry = OpRegistry::Global();
456 
457   // Check if we can further prune the graph, by removing the trivial ops.
458   absl::flat_hash_set<const NodeDef*> nodes_to_delete;
459   for (auto& node : runnable_item.graph.node()) {
460     if (!IsTrivialOp(node, graph_view)) {
461       continue;
462     }
463 
464     // Don't remove nodes that must be preserved.
465     if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) {
466       continue;
467     }
468 
469     // - Don't remove nodes that drive control dependencies.
470     // - Don't remove nodes that are driven by control dependencies either since
471     //   we can't ensure (yet) that we won't increase the number of control
472     //   dependency edges by deleting them (for example, removing a node driven
473     //   by 10 control edges and driving 10 control edges would result in the
474     //   creation of 100 edges).
475     // - Don't modify nodes that are connected to functions since that can
476     //   result in inlining failures later on.
477     // - Don't prune nodes that are driven by another device since these could
478     //   be used to reduce cross device communication.
479     // - Don't remove nodes that receive reference values, as those can be
480     //   converting references to non-references. It is important to preserve
481     //   these non-references since the partitioner will avoid sending
482     //   non-references across partitions more than once.
483     if (CanRemoveNode(node, graph_view, function_names, *op_registry)) {
484       nodes_to_delete.insert(&node);
485     }
486   }
487 
488   pruned_graph->Clear();
489   *pruned_graph->mutable_library() = item.graph.library();
490   *pruned_graph->mutable_versions() = item.graph.versions();
491 
492   if (nodes_to_delete.empty()) {
493     pruned_graph->mutable_node()->Swap(runnable_item.graph.mutable_node());
494     return Status::OK();
495   }
496 
497   const bool fetches_are_known = !item.fetch.empty();
498   pruned_graph->mutable_node()->Reserve(runnable_item.graph.node_size());
499   absl::flat_hash_map<string, const NodeDef*> optimized_nodes;
500   for (auto& node : runnable_item.graph.node()) {
501     if (!fetches_are_known ||
502         nodes_to_delete.find(&node) == nodes_to_delete.end()) {
503       NodeDef* new_node = pruned_graph->add_node();
504       *new_node = node;
505       new_node->clear_input();
506       ForwardInputs(node, nodes_to_delete, new_node, &optimized_nodes,
507                     graph_view);
508     }
509   }
510   VLOG(1) << "Pruned " << nodes_to_delete.size()
511           << " nodes from the graph. The graph now contains "
512           << pruned_graph->node_size() << " nodes.";
513   CHECK_LE(pruned_graph->node_size(), item.graph.node_size());
514 
515   return Status::OK();
516 }
517 
Feedback(Cluster * cluster,const GrapplerItem & item,const GraphDef & pruned_graph,double result)518 void ModelPruner::Feedback(Cluster* cluster, const GrapplerItem& item,
519                            const GraphDef& pruned_graph, double result) {
520   // Nothing to do for ModelPruner.
521 }
522 
523 }  // end namespace grappler
524 }  // end namespace tensorflow
525