1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
17
18 #include <unordered_set>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/function.pb.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/node_def_builder.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/framework/versions.pb.h"
28 #include "tensorflow/core/grappler/grappler_item.h"
29 #include "tensorflow/core/grappler/mutable_graph_view.h"
30 #include "tensorflow/core/grappler/op_types.h"
31 #include "tensorflow/core/grappler/utils.h"
32
33 namespace tensorflow {
34 namespace grappler {
35
IsTrivialIdentity(const NodeDef & node,const MutableGraphView & graph_view)36 bool IsTrivialIdentity(const NodeDef& node,
37 const MutableGraphView& graph_view) {
38 for (const auto input :
39 graph_view.GetFanins(node, /*include_controlling_nodes=*/true)) {
40 if (input.port_id == Graph::kControlSlot) {
41 // Node is driven by control dependency.
42 return false;
43 } else if (IsSwitch(*input.node)) { // Node is driven by switch.
44 return false;
45 }
46 }
47 for (const auto output :
48 graph_view.GetFanouts(node, /*include_controlled_nodes=*/true)) {
49 if (output.port_id == Graph::kControlSlot) {
50 // Node drives control dependency.
51 return false;
52 } else if (IsMerge(*output.node)) { // Node feeds merge.
53 return false;
54 }
55 }
56 return true;
57 }
58
IsTrivialOp(const NodeDef & node,const MutableGraphView & graph_view)59 bool IsTrivialOp(const NodeDef& node, const MutableGraphView& graph_view) {
60 // Remove the stop gradient nodes since they serve no purpose once the graph
61 // is built. Also remove Identity ops.
62 if (IsStopGradient(node)) {
63 return true;
64 }
65 if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
66 return IsTrivialIdentity(node, graph_view);
67 }
68
69 return IsAddN(node) && NumNonControlInputs(node) <= 1;
70 }
71
RemovalIncreasesEdgeCount(const NodeDef & node,const MutableGraphView & graph_view)72 bool RemovalIncreasesEdgeCount(const NodeDef& node,
73 const MutableGraphView& graph_view) {
74 int in_degree =
75 graph_view.NumFanins(node, /*include_controlling_nodes=*/true);
76 int out_degree =
77 graph_view.NumFanouts(node, /*include_controlling_nodes=*/true);
78 return in_degree * out_degree > in_degree + out_degree;
79 }
80
IsOutputPortRefValue(const NodeDef & node,int port_id,const OpRegistryInterface & op_registry)81 bool IsOutputPortRefValue(const NodeDef& node, int port_id,
82 const OpRegistryInterface& op_registry) {
83 const OpRegistrationData* op_reg_data = nullptr;
84 Status s = op_registry.LookUp(node.op(), &op_reg_data);
85 if (s.ok()) {
86 DataType output_type;
87 s = OutputTypeForNode(node, op_reg_data->op_def, port_id, &output_type);
88 if (s.ok() && IsRefType(output_type)) {
89 return true;
90 }
91 }
92 return false;
93 }
94
CanRemoveNode(const NodeDef & node,const MutableGraphView & graph_view,const absl::flat_hash_set<string> & function_names,const OpRegistryInterface & op_registry)95 bool CanRemoveNode(const NodeDef& node, const MutableGraphView& graph_view,
96 const absl::flat_hash_set<string>& function_names,
97 const OpRegistryInterface& op_registry) {
98 if (RemovalIncreasesEdgeCount(node, graph_view)) {
99 return false;
100 }
101 for (const auto input :
102 graph_view.GetFanins(node, /*include_controlling_nodes=*/true)) {
103 if (node.device() != input.node->device()) {
104 // Node is driven by a different device.
105 return false;
106 } else if (input.port_id == Graph::kControlSlot) {
107 // Node is driven by control dependency.
108 continue;
109 } else if (function_names.find(input.node->op()) != function_names.end()) {
110 // Node input is a function call.
111 return false;
112 } else if (IsOutputPortRefValue(*input.node, input.port_id, op_registry)) {
113 return false;
114 }
115 }
116 for (const auto output :
117 graph_view.GetFanouts(node, /*include_controlled_nodes=*/false)) {
118 if (function_names.find(output.node->op()) != function_names.end()) {
119 // Node output is a function call.
120 return false;
121 }
122 }
123 return true;
124 }
125
ForwardInputsInternal(const NodeDef & node,const absl::flat_hash_set<const NodeDef * > & nodes_to_delete,bool add_as_control,NodeDef * new_node,const absl::flat_hash_map<string,const NodeDef * > & optimized_nodes,const MutableGraphView & graph_view)126 void ForwardInputsInternal(
127 const NodeDef& node,
128 const absl::flat_hash_set<const NodeDef*>& nodes_to_delete,
129 bool add_as_control, NodeDef* new_node,
130 const absl::flat_hash_map<string, const NodeDef*>& optimized_nodes,
131 const MutableGraphView& graph_view) {
132 // To speed things up, use the optimized version of the node if
133 // available.
134 auto itr = optimized_nodes.find(node.name());
135 if (itr != optimized_nodes.end()) {
136 for (const string& input : itr->second->input()) {
137 *new_node->add_input() =
138 add_as_control ? AsControlDependency(NodeName(input)) : input;
139 }
140 return;
141 }
142 for (const auto& input : node.input()) {
143 const NodeDef* input_node = graph_view.GetNode(NodeName(input));
144 if (input_node == nullptr) {
145 // Invalid input, preserve it as is.
146 *new_node->add_input() =
147 add_as_control ? AsControlDependency(NodeName(input)) : input;
148 continue;
149 }
150 if (nodes_to_delete.find(input_node) != nodes_to_delete.end()) {
151 ForwardInputsInternal(*input_node, nodes_to_delete,
152 add_as_control || IsControlInput(input), new_node,
153 optimized_nodes, graph_view);
154 } else {
155 *new_node->add_input() =
156 add_as_control ? AsControlDependency(NodeName(input)) : input;
157 }
158 }
159 }
160
ForwardInputs(const NodeDef & original_node,const absl::flat_hash_set<const NodeDef * > & nodes_to_delete,NodeDef * new_node,absl::flat_hash_map<string,const NodeDef * > * optimized_nodes,const MutableGraphView & graph_view)161 void ForwardInputs(const NodeDef& original_node,
162 const absl::flat_hash_set<const NodeDef*>& nodes_to_delete,
163 NodeDef* new_node,
164 absl::flat_hash_map<string, const NodeDef*>* optimized_nodes,
165 const MutableGraphView& graph_view) {
166 // Forwards inputs of nodes to be deleted to their respective outputs.
167 ForwardInputsInternal(original_node, nodes_to_delete,
168 /*add_as_control=*/false, new_node, *optimized_nodes,
169 graph_view);
170 if (!new_node->name().empty()) {
171 (*optimized_nodes)[new_node->name()] = new_node;
172 }
173 // Reorder inputs such that control inputs come after regular inputs.
174 int pos = 0;
175 for (int i = 0; i < new_node->input_size(); ++i) {
176 if (!IsControlInput(new_node->input(i))) {
177 new_node->mutable_input()->SwapElements(pos, i);
178 ++pos;
179 }
180 }
181 DedupControlInputs(new_node);
182 }
183
IdentityNTerminalPorts(const NodeMap & node_map,const std::vector<string> & terminal_nodes,int graph_size)184 absl::flat_hash_map<string, absl::flat_hash_set<int>> IdentityNTerminalPorts(
185 const NodeMap& node_map, const std::vector<string>& terminal_nodes,
186 int graph_size) {
187 // Determines which ports for IdentityN nodes (that can be rewritten) lead to
188 // a terminal node.
189 std::vector<string> to_visit;
190 to_visit.reserve(graph_size);
191 // Set terminal nodes as visited so terminal nodes that may be IdentityN don't
192 // get pruned later on.
193 absl::flat_hash_set<string> visited(terminal_nodes.begin(),
194 terminal_nodes.end());
195 for (string terminal_node : terminal_nodes) {
196 NodeDef* node = node_map.GetNode(terminal_node);
197 if (node == nullptr) {
198 continue;
199 }
200 for (string input : node->input()) {
201 to_visit.push_back(input);
202 }
203 }
204
205 absl::flat_hash_set<string> identity_n_fanouts;
206 while (!to_visit.empty()) {
207 string curr = to_visit.back();
208 to_visit.pop_back();
209 NodeDef* curr_node = node_map.GetNode(curr);
210 if (curr_node == nullptr ||
211 visited.find(curr_node->name()) != visited.end()) {
212 continue;
213 }
214 // For IdentityN nodes, only traverse up through the port that comes from a
215 // terminal node along with control inputs. The IdentityN node is not marked
216 // as visited so other node input traversals can go through the other ports
217 // of the IdentityN node.
218 if (IsIdentityN(*curr_node)) {
219 if (identity_n_fanouts.find(curr) == identity_n_fanouts.end()) {
220 identity_n_fanouts.emplace(curr);
221 int pos = NodePositionIfSameNode(curr, curr_node->name());
222 if (pos >= 0) {
223 to_visit.push_back(curr_node->input(pos));
224 }
225 for (const string& input : curr_node->input()) {
226 if (IsControlInput(input) &&
227 identity_n_fanouts.find(input) == identity_n_fanouts.end()) {
228 to_visit.push_back(input);
229 }
230 }
231 }
232 } else {
233 for (const string& input : curr_node->input()) {
234 to_visit.push_back(input);
235 }
236 visited.emplace(curr_node->name());
237 }
238 }
239
240 absl::flat_hash_map<string, absl::flat_hash_set<int>> identity_n_ports;
241 for (const auto& fanout : identity_n_fanouts) {
242 int pos;
243 string node_name = ParseNodeName(fanout, &pos);
244 if (node_name.empty() || pos < 0) { // Exclude control inputs.
245 continue;
246 }
247 if (identity_n_ports.find(node_name) == identity_n_ports.end()) {
248 identity_n_ports[node_name] = {pos};
249 } else {
250 identity_n_ports[node_name].emplace(pos);
251 }
252 }
253
254 return identity_n_ports;
255 }
256
NewIdentityFromIdentityN(int pos,const NodeDef & identity_n,GraphDef * graph,NodeMap * node_map)257 string NewIdentityFromIdentityN(int pos, const NodeDef& identity_n,
258 GraphDef* graph, NodeMap* node_map) {
259 // TODO(lyandy): Migrate over to GrapplerOptimizerStage and use
260 // OptimizedNodeName for new node name.
261 string new_node_name =
262 strings::StrCat(identity_n.name(), "-", pos, "-grappler-ModelPruner");
263 if (node_map->NodeExists(new_node_name)) {
264 return "";
265 }
266 NodeDef* new_node = graph->add_node();
267 Status status = NodeDefBuilder(new_node_name, "Identity")
268 .Input(identity_n.input(pos), 0,
269 identity_n.attr().at("T").list().type(pos))
270 .Device(identity_n.device())
271 .Finalize(new_node);
272 if (!status.ok()) {
273 return "";
274 }
275 node_map->AddNode(new_node->name(), new_node);
276 node_map->AddOutput(NodeName(new_node->input(0)), new_node->name());
277 return new_node->name();
278 }
279
RewriteIdentityNAndInputsOutputs(NodeDef * node,int num_non_control_inputs,const absl::flat_hash_set<int> & terminal_ports,GraphDef * graph,NodeMap * node_map)280 Status RewriteIdentityNAndInputsOutputs(
281 NodeDef* node, int num_non_control_inputs,
282 const absl::flat_hash_set<int>& terminal_ports, GraphDef* graph,
283 NodeMap* node_map) {
284 // Rewrite IdentityN node and associated inputs and outputs. For inputs and
285 // outputs that don't lead to a terminal node, a new Identity node is created
286 // and those inputs and outputs are rewritten to use the new Identity node as
287 // their outputs and inputs respectively. For the remaining nodes, the ouputs
288 // have their inputs updated with the adjusted port, from the IdentityN node
289 // having less inputs.
290 struct NodeOutputUpdate {
291 string input;
292 string output;
293 };
294
295 absl::flat_hash_map<int, int> terminal_input_pos;
296 absl::flat_hash_map<int, string> new_identities;
297 int new_idx = 0;
298 for (int i = 0; i < num_non_control_inputs; i++) {
299 if (terminal_ports.find(i) != terminal_ports.end()) {
300 terminal_input_pos[i] = new_idx++;
301 } else {
302 string identity = NewIdentityFromIdentityN(i, *node, graph, node_map);
303 if (identity.empty()) {
304 // Fail early when creating Identity from IdentityN errors.
305 return errors::Internal(
306 "Could not create Identity node from IdentityN node ", node->name(),
307 " at port ", i);
308 }
309 new_identities[i] = identity;
310 }
311 }
312
313 std::vector<NodeOutputUpdate> updates;
314 for (NodeDef* output : node_map->GetOutputs(node->name())) {
315 for (int i = 0; i < output->input_size(); i++) {
316 string input = output->input(i);
317 if (IsControlInput(input)) {
318 continue;
319 }
320 TensorId input_tensor = ParseTensorName(input);
321 if (input_tensor.node() == node->name()) {
322 if (terminal_ports.find(input_tensor.index()) == terminal_ports.end()) {
323 // Replace input that does not lead to a terminal node with newly
324 // created identity.
325 string new_identity = new_identities[input_tensor.index()];
326 output->set_input(i, new_identity);
327 updates.push_back({new_identity, output->name()});
328 } else {
329 // Update input ports that lead to a terminal node from splitting
330 // inputs.
331 int new_pos = terminal_input_pos[input_tensor.index()];
332 string updated_input_name =
333 new_pos > 0 ? strings::StrCat(node->name(), ":", new_pos)
334 : node->name();
335 output->set_input(i, updated_input_name);
336 }
337 }
338 }
339 }
340
341 for (NodeOutputUpdate update : updates) {
342 node_map->AddOutput(update.input, update.output);
343 }
344
345 // Update inputs and types by removing inputs that were split away from
346 // main IdentityN node.
347 const int num_inputs = node->input_size();
348 int curr_pos = 0;
349 auto mutable_inputs = node->mutable_input();
350 auto mutable_types =
351 node->mutable_attr()->at("T").mutable_list()->mutable_type();
352 for (int i = 0; i < num_non_control_inputs; i++) {
353 if (terminal_input_pos.find(i) != terminal_input_pos.end()) {
354 mutable_inputs->SwapElements(i, curr_pos);
355 mutable_types->SwapElements(i, curr_pos);
356 curr_pos++;
357 }
358 }
359 mutable_types->Truncate(curr_pos);
360 // Control inputs.
361 for (int i = num_non_control_inputs; i < num_inputs; i++) {
362 mutable_inputs->SwapElements(i, curr_pos++);
363 }
364 mutable_inputs->DeleteSubrange(curr_pos, num_inputs - curr_pos);
365
366 return Status::OK();
367 }
368
SplitIdentityNInputs(GraphDef * graph,const std::vector<string> & terminal_nodes,bool * updated_graph)369 Status SplitIdentityNInputs(GraphDef* graph,
370 const std::vector<string>& terminal_nodes,
371 bool* updated_graph) {
372 // For inputs of IdentityN nodes that do not lead to a terminal node, remove
373 // them from IdentityN and create new individual Identity nodes. This will
374 // allow ModelPruner to possibly remove nodes in the transitive fanin of the
375 // newly created Identity nodes.
376 NodeMap node_map(graph);
377
378 for (auto const& terminal :
379 IdentityNTerminalPorts(node_map, terminal_nodes, graph->node_size())) {
380 NodeDef* node = node_map.GetNode(terminal.first);
381 if (node == nullptr) {
382 continue;
383 }
384
385 const int num_non_control_inputs = NumNonControlInputs(*node);
386 if (node->attr().count("T") == 0 ||
387 node->attr().at("T").list().type_size() != num_non_control_inputs ||
388 terminal.second.size() >= num_non_control_inputs) {
389 continue;
390 }
391
392 TF_RETURN_IF_ERROR(RewriteIdentityNAndInputsOutputs(
393 node, num_non_control_inputs, terminal.second, graph, &node_map));
394 *updated_graph = true;
395 }
396
397 return Status::OK();
398 }
399
SetTransitiveFaninGraph(const GraphDef & input_graph,GraphDef * output_graph,const std::vector<string> & terminal_nodes)400 Status SetTransitiveFaninGraph(const GraphDef& input_graph,
401 GraphDef* output_graph,
402 const std::vector<string>& terminal_nodes) {
403 // Determines transitive fanin nodes from terminal nodes and add them to the
404 // output graph.
405 bool ill_formed = false;
406 std::vector<const NodeDef*> keep =
407 ComputeTransitiveFanin(input_graph, terminal_nodes, &ill_formed);
408 if (ill_formed) {
409 // Some graph edges are invalid, or some of the feeds/fetch don't exist:
410 // let's be conservative and preserve the graph as is.
411 return errors::InvalidArgument("Invalid input graph.");
412 }
413 // Try to keep the nodes ordered somewhat topologically since this helps
414 // further optimizations perform better.
415 output_graph->mutable_node()->Reserve(keep.size());
416 for (int i = keep.size() - 1; i >= 0; --i) {
417 *output_graph->add_node() = *keep[i];
418 }
419
420 return Status::OK();
421 }
422
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * pruned_graph)423 Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
424 GraphDef* pruned_graph) {
425 const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve();
426
427 // Prune all the nodes that won't be executed, ie all the nodes that aren't in
428 // the fanin of a fetch node. If fetch nodes aren't specified, we'll assume
429 // the whole graph might be executed.
430 GrapplerItem runnable_item;
431 if (!nodes_to_preserve.empty()) {
432 std::vector<string> terminal_nodes(nodes_to_preserve.begin(),
433 nodes_to_preserve.end());
434 std::sort(terminal_nodes.begin(), terminal_nodes.end());
435 TF_RETURN_IF_ERROR(SetTransitiveFaninGraph(item.graph, &runnable_item.graph,
436 terminal_nodes));
437 bool did_split_identity_n = false;
438 TF_RETURN_IF_ERROR(SplitIdentityNInputs(
439 &runnable_item.graph, terminal_nodes, &did_split_identity_n));
440 if (did_split_identity_n) {
441 GraphDef fanin_split_identity_n_graph;
442 TF_RETURN_IF_ERROR(SetTransitiveFaninGraph(
443 runnable_item.graph, &fanin_split_identity_n_graph, terminal_nodes));
444 runnable_item.graph.Swap(&fanin_split_identity_n_graph);
445 }
446 } else {
447 runnable_item = item;
448 }
449
450 MutableGraphView graph_view(&runnable_item.graph);
451 absl::flat_hash_set<string> function_names;
452 for (const auto& function : item.graph.library().function()) {
453 function_names.insert(function.signature().name());
454 }
455 OpRegistryInterface* op_registry = OpRegistry::Global();
456
457 // Check if we can further prune the graph, by removing the trivial ops.
458 absl::flat_hash_set<const NodeDef*> nodes_to_delete;
459 for (auto& node : runnable_item.graph.node()) {
460 if (!IsTrivialOp(node, graph_view)) {
461 continue;
462 }
463
464 // Don't remove nodes that must be preserved.
465 if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) {
466 continue;
467 }
468
469 // - Don't remove nodes that drive control dependencies.
470 // - Don't remove nodes that are driven by control dependencies either since
471 // we can't ensure (yet) that we won't increase the number of control
472 // dependency edges by deleting them (for example, removing a node driven
473 // by 10 control edges and driving 10 control edges would result in the
474 // creation of 100 edges).
475 // - Don't modify nodes that are connected to functions since that can
476 // result in inlining failures later on.
477 // - Don't prune nodes that are driven by another device since these could
478 // be used to reduce cross device communication.
479 // - Don't remove nodes that receive reference values, as those can be
480 // converting references to non-references. It is important to preserve
481 // these non-references since the partitioner will avoid sending
482 // non-references across partitions more than once.
483 if (CanRemoveNode(node, graph_view, function_names, *op_registry)) {
484 nodes_to_delete.insert(&node);
485 }
486 }
487
488 pruned_graph->Clear();
489 *pruned_graph->mutable_library() = item.graph.library();
490 *pruned_graph->mutable_versions() = item.graph.versions();
491
492 if (nodes_to_delete.empty()) {
493 pruned_graph->mutable_node()->Swap(runnable_item.graph.mutable_node());
494 return Status::OK();
495 }
496
497 const bool fetches_are_known = !item.fetch.empty();
498 pruned_graph->mutable_node()->Reserve(runnable_item.graph.node_size());
499 absl::flat_hash_map<string, const NodeDef*> optimized_nodes;
500 for (auto& node : runnable_item.graph.node()) {
501 if (!fetches_are_known ||
502 nodes_to_delete.find(&node) == nodes_to_delete.end()) {
503 NodeDef* new_node = pruned_graph->add_node();
504 *new_node = node;
505 new_node->clear_input();
506 ForwardInputs(node, nodes_to_delete, new_node, &optimized_nodes,
507 graph_view);
508 }
509 }
510 VLOG(1) << "Pruned " << nodes_to_delete.size()
511 << " nodes from the graph. The graph now contains "
512 << pruned_graph->node_size() << " nodes.";
513 CHECK_LE(pruned_graph->node_size(), item.graph.node_size());
514
515 return Status::OK();
516 }
517
Feedback(Cluster * cluster,const GrapplerItem & item,const GraphDef & pruned_graph,double result)518 void ModelPruner::Feedback(Cluster* cluster, const GrapplerItem& item,
519 const GraphDef& pruned_graph, double result) {
520 // Nothing to do for ModelPruner.
521 }
522
523 } // end namespace grappler
524 } // end namespace tensorflow
525