1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/utils/transitive_fanin.h"
17 
18 #include <queue>
19 #include <vector>
20 
21 #include "tensorflow/core/framework/node_def_util.h"
22 #include "tensorflow/core/grappler/utils.h"
23 #include "tensorflow/core/platform/errors.h"
24 
25 namespace tensorflow {
26 namespace grappler {
27 
ComputeTransitiveFanin(const GraphDef & graph,const std::vector<string> & terminal_nodes,std::unordered_map<string,const NodeDef * > * name_to_fanin_node,std::vector<const NodeDef * > * fanin_nodes)28 Status ComputeTransitiveFanin(
29     const GraphDef& graph, const std::vector<string>& terminal_nodes,
30     std::unordered_map<string, const NodeDef*>* name_to_fanin_node,
31     std::vector<const NodeDef*>* fanin_nodes) {
32   std::unordered_map<string, const NodeDef*> name_to_node;
33   std::unordered_map<string, const NodeDef*> name_to_send;
34   for (const auto& node : graph.node()) {
35     name_to_node[node.name()] = &node;
36     if (node.op() == "_Send") {
37       const auto& attr = node.attr();
38       name_to_send[attr.at("tensor_name").s()] = &node;
39     }
40   }
41 
42   std::vector<const NodeDef*> queue;
43   for (const string& root : terminal_nodes) {
44     const NodeDef* node = name_to_node[NodeName(root)];
45     if (!node) {
46       return errors::InvalidArgument("Graph does not contain terminal node ",
47                                      root, ".");
48     }
49     queue.push_back(node);
50   }
51 
52   std::unordered_set<const NodeDef*> visited;
53 
54   while (!queue.empty()) {
55     const NodeDef* node = queue.back();
56     queue.pop_back();
57     if (!visited.insert(node).second) {
58       // The node has already been visited.
59       continue;
60     }
61     fanin_nodes->push_back(node);
62     if (name_to_fanin_node) {
63       name_to_fanin_node->insert(
64           std::pair<string, const NodeDef*>(node->name(), node));
65     }
66     for (const string& input : node->input()) {
67       const NodeDef* in = name_to_node[NodeName(input)];
68       if (!in) {
69         return errors::InvalidArgument("Graph does not contain input ",
70                                        NodeName(input), " of node ",
71                                        node->name(), ".");
72       }
73       queue.push_back(in);
74     }
75     if (node->op() == "_Recv") {
76       const auto& attr = node->attr();
77       const NodeDef* send = name_to_send[attr.at("tensor_name").s()];
78       if (send) {
79         queue.push_back(send);
80       }
81       // Subgraph after partitioning may have either _Send or _Recv, not both.
82       // So, we do not set ill_formed for missing _Send.
83     }
84   }
85   return Status::OK();
86 }
87 
ComputeTransitiveFanin(const GraphDef & graph,const std::vector<string> & terminal_nodes,std::vector<const NodeDef * > * fanin_nodes)88 Status ComputeTransitiveFanin(const GraphDef& graph,
89                               const std::vector<string>& terminal_nodes,
90                               std::vector<const NodeDef*>* fanin_nodes) {
91   return ComputeTransitiveFanin(graph, terminal_nodes, nullptr, fanin_nodes);
92 }
93 
SetTransitiveFaninGraph(const GraphDef & input_graph,GraphDef * output_graph,const std::vector<string> & terminal_nodes)94 Status SetTransitiveFaninGraph(const GraphDef& input_graph,
95                                GraphDef* output_graph,
96                                const std::vector<string>& terminal_nodes) {
97   // Determines transitive fanin nodes from terminal nodes and add them to the
98   // output graph.
99   std::vector<const NodeDef*> keep;
100   TF_RETURN_IF_ERROR(
101       ComputeTransitiveFanin(input_graph, terminal_nodes, &keep));
102   // Try to keep the nodes ordered somewhat topologically since this helps
103   // further optimizations perform better.
104   output_graph->mutable_node()->Reserve(keep.size());
105   for (int i = keep.size() - 1; i >= 0; --i) {
106     *output_graph->add_node() = *keep[i];
107   }
108 
109   return Status::OK();
110 }
111 
112 }  // namespace grappler
113 }  // namespace tensorflow
114