1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/utils/colocation.h"
17 
18 #include <cstring>
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/node_def.pb.h"
21 #include "tensorflow/core/grappler/utils.h"
22 
23 namespace tensorflow {
24 namespace grappler {
25 
26 namespace {
27 
// Returns the root (representative) node name of the colocation group that
// `node_name` belongs to.
//
// `map` is a disjoint-set forest stored as "node name -> parent name"; a root
// is an entry that maps to itself. If `node_name` has no entry yet, it is
// inserted as the root of a new single-node group and returned unchanged.
//
// Side effect: every node on the path from `node_name` to the root is
// re-pointed directly at the root (path compression), so subsequent lookups
// through those nodes take a single hop.
string GetColocationGroupRoot(std::unordered_map<string, string>* map,
                              const string& node_name) {
  if (map->find(node_name) == map->end()) {
    // Unknown node: it starts out as the root of its own group.
    map->insert({node_name, node_name});
    return node_name;
  }

  // Pass 1: follow parent links until reaching the entry that is its own
  // parent. std::unordered_map never relocates stored elements (insertions
  // and rehashes only invalidate iterators), so pointing at the mapped
  // strings avoids copying a name on every hop.
  const string* cur = &node_name;
  while (true) {
    const string& parent = (*map)[*cur];
    if (parent == *cur) {
      break;
    }
    cur = &parent;
  }
  // Copy the root name before rewriting any links.
  string root = *cur;

  // Pass 2: walk the same path again and point every link at the root. The
  // walk stops at the first link that already holds the root, which is either
  // the root's own entry or its direct child.
  string* link = &(*map)[node_name];
  while (*link != root) {
    string* next_link = &(*map)[*link];
    *link = root;
    link = next_link;
  }
  return root;
}
59 
// Unions two colocation groups by attaching one root beneath the other.
//
// `left` and `right` are expected to be the root nodes of their respective
// groups. After the call `right` is parented to `left`, so `left` represents
// the combined group. The call is a no-op when either name has no entry in
// `map`, or when both names are the same.
void MergeColocationGroup(std::unordered_map<string, string>* map,
                          const string& left, const string& right) {
  const auto right_entry = map->find(right);
  const bool both_known = right_entry != map->end() && map->count(left) != 0;
  if (!both_known || left == right) {
    return;
  }
  // Re-parent the right root under the left root.
  right_entry->second = left;
}
74 }  // namespace
75 
76 // Use of disjoint set algorithm to build the colocation groups from the input
77 // graph. The core data structure in use is a hash map from one node to its
78 // parent node. Whenever we see two nodes colocate with each other, we merge
79 // their colocation groups together. After we traverse all colocation pairs
80 // in the graph, we will have several disjoint sets. Then we pick the root node
81 // of each disjoint set as the representative node, and let all other nodes in
82 // the group colocate with the representative node.
ReassignColocation(GraphDef * graph)83 void ReassignColocation(GraphDef* graph) {
84   constexpr char kClassAttr[] = "_class";
85   constexpr char kColocPrefix[] = "loc:@";
86 
87   // A hashmap that maps from a node name to its parent node name.
88   std::unordered_map<string, string> coloc_groups;
89   NodeMap node_map(graph);
90   for (const auto& node : graph->node()) {
91     auto iter = node.attr().find(kClassAttr);
92     if (iter != node.attr().end() && iter->second.has_list()) {
93       for (const auto& str : iter->second.list().s()) {
94         size_t pos = str.find(kColocPrefix);
95         if (pos == 0) {
96           // After we find a colocation, update the colocation groups.
97           string colocate_node = str.substr(pos + strlen(kColocPrefix));
98           MergeColocationGroup(
99               &coloc_groups, GetColocationGroupRoot(&coloc_groups, node.name()),
100               GetColocationGroupRoot(&coloc_groups, colocate_node));
101         }
102       }
103     }
104   }
105 
106   // We use the root node of each colocation groups as its representative
107   // node. For each node in one group, colocate with the representative node
108   // if the node is in the graph.
109   for (const auto& pair : coloc_groups) {
110     if (pair.first != pair.second) {
111       // This is a child node.
112       NodeDef* node = node_map.GetNode(pair.first);
113       if (node) {
114         // Colocate this node with the root node.
115         AttrValue new_value;
116         new_value.mutable_list()->add_s(
117             kColocPrefix + GetColocationGroupRoot(&coloc_groups, pair.first));
118         node->mutable_attr()->erase(kClassAttr);
119         node->mutable_attr()->insert({kClassAttr, new_value});
120       }
121     } else {
122       // This is a root node. Clear the _class attribute.
123       NodeDef* node = node_map.GetNode(pair.first);
124       if (node) {  // root node should always exist in the graph as guaranteed
125                    // by order of merging. Just put check here to ensure safety.
126         node->mutable_attr()->erase(kClassAttr);
127       }
128     }
129   }
130 }
131 
132 }  // namespace grappler
133 }  // namespace tensorflow
134