1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H_
16 #define TENSORFLOW_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H_
17 
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
29 
30 namespace toco {
31 
32 // The base class for Cluster. A cluster is group of nodes all related to each
33 // other because their name match a given "pattern", which shows they all belong
34 // to a composite op supported in TFLite. The nodes in a cluster will be
35 // collapsed into a single composite op node plus a series of constant nodes
36 // holding the input parameters to that node. The nodes in a cluster are assumed
37 // to be using the same device. By changing the "pattern" we can have different
38 // subclasses of the base Cluster class.
39 class Cluster {
40  public:
~Cluster()41   virtual ~Cluster() {}
42 
43   virtual void CreateNodes() = 0;
44 
45   // Save the following info from the original GraphDef this cluster is from:
46   // 1- a pointer to the GraphDef
47   // 2- All the nodes in GraphDef which belong to this cluster.
48   void SetGraphDefInfo(const tensorflow::GraphDef* graph_def);
49 
GetName()50   const std::string& GetName() const { return name_; }
51 
GetNewNodes()52   const std::vector<std::unique_ptr<tensorflow::NodeDef>>& GetNewNodes() const {
53     return new_nodes_;
54   }
55 
GetNodes()56   const std::vector<const tensorflow::NodeDef*>& GetNodes() { return nodes_; }
57 
SetName(const std::string & name)58   void SetName(const std::string& name) { name_ = name; }
59 
SetDevice(const std::string & device)60   void SetDevice(const std::string& device) { device_ = device; }
61 
62   // Find the input(s) and output(s) of this Cluster.
63   bool FindClusterInputsAndOutputs();
64 
65  protected:
66   std::string name_;
67   std::string device_;
68   std::vector<std::string> inputs_;
69   std::vector<std::string> outputs_;
70 
71   // Used to hold the pointers to nodes which are in this cluster. These nodes
72   // are pointing to the nodes in graph_def_.
73   std::vector<const tensorflow::NodeDef*> nodes_;
74 
75   // Used to cache the newly generated nodes: like the nodes created by
76   // collapsing Const nodes, or the nodes which is used to show the composite
77   // op.
78   std::vector<std::unique_ptr<tensorflow::NodeDef>> new_nodes_;
79 
80   const tensorflow::GraphDef* graph_def_; /*Not owned*/
81 };
82 
83 // A factory interface for cluster class.
84 // It defines a virtual function interface which is responsible for creating
85 // a cluster. Each cluster factory is responsible to pack a cluster of nodes
86 // into a cluster using a name-based pattern matching approach.
87 class ClusterFactoryInterface {
88  public:
~ClusterFactoryInterface()89   virtual ~ClusterFactoryInterface() {}
90 
91   // Creates a cluster of nodes using a name-based pattern matching approach. It
92   // uses a node as a seed and if its name matches a certain pattern, then it
93   // builds the cluster around that node.
94   virtual std::unique_ptr<Cluster> CreateCluster(
95       const tensorflow::NodeDef& node,
96       const tensorflow::GraphDef& graph_def) const = 0;
97 };
98 
99 }  // end namespace toco
100 
101 #endif  // TENSORFLOW_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H_
102