1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
16 
17 #include <string>
18 #include <unordered_map>
19 #include <vector>
20 
21 #include "tensorflow/lite/toco/tensorflow_graph_matching/cluster.h"
22 #include "tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.h"
23 #include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.h"
24 #include "tensorflow/lite/toco/tooling_util.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/function.pb.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 
30 namespace toco {
31 
32 using tensorflow::GraphDef;
33 using tensorflow::NodeDef;
34 
AddNodeToGraph(const NodeDef & node,const std::vector<std::string> & cluster_names,GraphDef * graph)35 void AddNodeToGraph(const NodeDef& node,
36                     const std::vector<std::string>& cluster_names,
37                     GraphDef* graph) {
38   NodeDef* new_node = graph->add_node();
39   new_node->set_op(node.op());
40   new_node->set_name(node.name());
41   new_node->set_device(node.device());
42   // If the inputs are coming from a node which belongs to another cluster, then
43   // those inputs are renamed to the source cluster name. Otherwise the original
44   // input name is used.
45   for (const std::string& node_input : node.input()) {
46     bool input_from_cluster = false;
47     for (const std::string& cluster_name : cluster_names) {
48       if (StrContains(node_input, cluster_name) &&
49           !StrContains(node.name(), cluster_name)) {
50         new_node->add_input(cluster_name);
51         input_from_cluster = true;
52         break;
53       }
54     }
55     if (!input_from_cluster) {
56       new_node->add_input(node_input);
57     }
58   }
59   for (const auto& attr : node.attr()) {
60     (*new_node->mutable_attr())[attr.first] = attr.second;
61   }
62 }
63 
FindCluster(const ClusterFactoryInterface & cluster_factory,const GraphDef & graph_def,std::unordered_map<std::string,bool> * is_node_in_cluster,std::vector<std::unique_ptr<Cluster>> * clusters)64 bool FindCluster(const ClusterFactoryInterface& cluster_factory,
65                  const GraphDef& graph_def,
66                  std::unordered_map<std::string, bool>* is_node_in_cluster,
67                  std::vector<std::unique_ptr<Cluster>>* clusters) {
68   for (const NodeDef& node : graph_def.node()) {
69     // If the node is not assigned to any cluster, then we check if it belong to
70     // the cluster_factory.
71     bool node_in_cluster = (*is_node_in_cluster)[node.name()];
72     if (!node_in_cluster) {
73       std::unique_ptr<Cluster> cluster =
74           cluster_factory.CreateCluster(node, graph_def);
75       if (cluster) {
76         // Label all the nodes in is_node_in_cluster which are in this cluster
77         // as belonged to this cluster.
78         for (const NodeDef* cluster_node : cluster->GetNodes()) {
79           (*is_node_in_cluster)[cluster_node->name()] = true;
80         }
81         clusters->push_back(std::move(cluster));
82       }
83     }
84   }
85   return (!clusters->empty());
86 }
87 
MaybeResolveClusters(const GraphDef & graph_def,const std::vector<ClusterFactoryInterface * > & cluster_factories)88 std::unique_ptr<GraphDef> MaybeResolveClusters(
89     const GraphDef& graph_def,
90     const std::vector<ClusterFactoryInterface*>& cluster_factories) {
91   std::unique_ptr<GraphDef> pruned_graph(new GraphDef);
92   // The structure to keep track of which cluster each node is assigned to, and
93   // to initialize them to all un-assigned,
94   std::unordered_map<std::string, bool> is_node_in_cluster;
95   for (const NodeDef& node : graph_def.node()) {
96     is_node_in_cluster[node.name()] = false;
97   }
98 
99   std::vector<std::string> cluster_names;
100   std::vector<std::unique_ptr<Cluster>> all_clusters;
101   // Find the clusters for all available cluster factories.
102   for (const ClusterFactoryInterface* cluster_factory : cluster_factories) {
103     std::vector<std::unique_ptr<Cluster>> clusters;
104     if (FindCluster(*cluster_factory, graph_def, &is_node_in_cluster,
105                     &clusters)) {
106       for (auto itr = clusters.begin(); itr != clusters.end(); ++itr) {
107         cluster_names.push_back((*itr)->GetName());
108         (*itr)->CreateNodes();
109         all_clusters.push_back(std::move(*itr));
110       }
111     }
112   }
113 
114   for (const std::unique_ptr<Cluster>& cluster : all_clusters) {
115     for (const std::unique_ptr<tensorflow::NodeDef>& src_node :
116          cluster->GetNewNodes()) {
117       // Add it to the output GraphDef.
118       AddNodeToGraph(*src_node, cluster_names, pruned_graph.get());
119     }
120   }
121 
122   // Add any node which is not part of a cluster.
123   for (const NodeDef& node : graph_def.node()) {
124     bool node_in_cluster = is_node_in_cluster[node.name()];
125     if (!node_in_cluster) {
126       AddNodeToGraph(node, cluster_names, pruned_graph.get());
127     }
128   }
129 
130   if (pruned_graph->node_size() == 0) {
131     return nullptr;
132   } else {
133     return pruned_graph;
134   }
135 }
136 
MaybeReplaceCompositeSubgraph(const GraphDef & tf_graph)137 std::unique_ptr<GraphDef> MaybeReplaceCompositeSubgraph(
138     const GraphDef& tf_graph) {
139   SvdfClusterFactory svdf_cluster_factory;
140 
141   std::vector<ClusterFactoryInterface*> cluster_factories;
142   cluster_factories.push_back(&svdf_cluster_factory);
143 
144   std::unique_ptr<GraphDef> pruned_graph =
145       MaybeResolveClusters(tf_graph, cluster_factories);
146 
147   // Copy function definitions
148   if (pruned_graph) {
149     *(pruned_graph->mutable_library()) = tf_graph.library();
150   }
151   return pruned_graph;
152 }
153 
154 }  // end namespace toco
155