1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_
17 #define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_
18 
19 #include <initializer_list>
20 #include <set>
21 #include <unordered_set>
22 
23 #include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
24 #include "tensorflow/core/grappler/graph_analyzer/map_tools.h"
25 #include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
26 #include "tensorflow/core/lib/gtl/flatset.h"
27 
28 namespace tensorflow {
29 namespace grappler {
30 namespace graph_analyzer {
31 
32 // The description of a single subgraph for processing.
33 class Subgraph {
34  public:
35   // Identity of a single subgraph as a set of nodes.
36   class Identity : public gtl::FlatSet<const GenNode*> {
37    public:
38     using InitializerList = std::initializer_list<GenNode*>;
39 
40     Identity() = default;
41     Identity(InitializerList init);
42     bool operator<(const Identity& other) const;
43     bool operator==(const Identity& other) const;
44 
45     // Compute the hash.
46     size_t Hash() const;
47   };
48 
Subgraph(Identity id)49   explicit Subgraph(Identity id) : id_(std::move(id)), hash_(id_.Hash()) {}
50 
51   // Construct by extending the parent identity with an extra node.
52   Subgraph(const Identity& parent_id, GenNode* add_node);
53 
54   Subgraph() = delete;
55   Subgraph(const Subgraph& other) = delete;
56   void operator=(const Subgraph& other) = delete;
57 
58   // Order for building sets of subgraphs.
59   bool operator<(const Subgraph& other) const { return this->id_ < other.id_; }
60   // Support for hashed sets.
61   bool operator==(const Subgraph& other) const {
62     return this->id_ == other.id_;
63   }
Hash()64   size_t Hash() const { return hash_; }
65 
66   // Dump the subgraph information to a string.
67   string Dump();
68 
69   // Extract this subgraph into a separate graph representation for signature
70   // building, that includes only the links between the nodes in the subgraph
71   // and drops all the external links. The result map should be clear before the
72   // call.
73   void ExtractForSignature(SigNodeMap* result);
74 
id()75   const Identity& id() const { return id_; }
specific()76   bool specific() const { return specific_; }
SetSpecific(bool value)77   void SetSpecific(bool value) { specific_ = value; }
collation_count()78   int32_t collation_count() const { return collation_count_; }
79   void AddCollation(int32_t n = 1) { collation_count_ += n; }
ResetCollation()80   void ResetCollation() { collation_count_ = 1; }
MergeCollation(const Subgraph & other)81   void MergeCollation(const Subgraph& other) {
82     collation_count_ += other.collation_count_;
83   }
84 
85  private:
86   // Identity also serves as the list of nodes. It never changes throughout the
87   // life of subgraph.
88   Identity id_;
89   size_t hash_;  // Cached from the identity.
90   // Whether the dump should include the specific names of the nodes. The
91   // non-specific (i.e. generic) subgraphs represent a collation of multiple
92   // subgraphs.
93   bool specific_ = true;
94   // How many collated subgraphs are represented by this subgraph.
95   int32_t collation_count_ = 1;
96 };
97 
98 // Iteration of all links in a subgraph. This is more like Java iterators than
99 // the normal C++ iterators. It's simpler this way and there seems to be no
100 // major reason to make it a proper C++ iterator.
101 class SubgraphIterator {
102  public:
103   // Obviously an iterator is valid only until the original object
104   // gets destroyed.
105   explicit SubgraphIterator(const Subgraph::Identity* id);
SubgraphIterator(const Subgraph * sg)106   explicit SubgraphIterator(const Subgraph* sg) : SubgraphIterator(&sg->id()) {}
107 
108   // Check whether the built-in iterator is at the end.
AtEnd()109   bool AtEnd() const { return id_it_ == id_->end(); }
110 
111   // Get the neighbor at the current iterator.
112   // MUST NOT be called when AtEnd();
GetNeighbor()113   const GenNode::LinkTarget& GetNeighbor() const {
114     return link_map_it_->second[link_idx_];
115   }
116 
117   // Get the node at the current iterator.
118   // MUST NOT be called when AtEnd();
GetNode()119   const GenNode* GetNode() const { return *id_it_; }
120 
121   // Get the port leading to the neighbor at the current iterator.
122   // MUST NOT be called when AtEnd();
GetPort()123   GenNode::Port GetPort() const { return link_map_it_->first; }
124 
125   // Increases the iterator.
126   // Returns true if NOT AtEnd() after increasing the iterator.
127   // Safe to call if already AtEnd().
128   bool Next();
129 
130   // If there are more links at the same port, increases the iterator and
131   // returns true. Otherwise leaves the iterator unchanged and returns false.
132   bool NextIfSamePort();
133 
134   // Increases the iterator directly to the last position on the current port
135   // (or if already there then doesn't increase). Equivalent to calling
136   // NextIfSamePort() while it returns true, but faster.
137   // Safe to call if already AtEnd().
138   void SkipPort();
139 
140   // Increases the iterator directly to the last position on the current node.
141   // Safe to call if already AtEnd().
142   void SkipNode();
143 
144   // Returns true if the iterators are exactly the same.
145   bool operator==(const SubgraphIterator& other) const;
146   bool operator!=(const SubgraphIterator& other) const {
147     return !(*this == other);
148   }
149 
150  private:
151   // After link_idx_ has been increased, make sure that it points to the
152   // next valid element (or end) by increasing the higher levels of iteration if
153   // needed.
154   // Returns true if NOT AtEnd() after increasing the iterator.
155   // NOT safe to call if already AtEnd().
156   bool PropagateNext();
157 
158   // Identity of the subgraph being iterated over.
159   const Subgraph::Identity* id_;
160 
161   // The current position, allowing to iterate through the links (see the
162   // reasoning for it in the public section).
163   //
164   // (1) Iterator of the nodes in the subgraph.
165   Subgraph::Identity::const_iterator id_it_;
166   // (2) Iterator in the link map of the node.
167   GenNode::LinkMap::const_iterator link_map_it_;
168   // (3) Index in the vector of the links.
169   int32_t link_idx_;
170 };
171 
172 // A convenient way to store subgraphs: in a set of unique_ptrs. This way the
173 // addresses of subgraph objects will stay stable, and the objects themselves
174 // won't be copied.
175 class SubgraphPtrSet
176     : public std::unordered_set<std::unique_ptr<Subgraph>,
177                                 HashAtPtr<std::unique_ptr<Subgraph>>,
178                                 EqAtPtr<std::unique_ptr<Subgraph>>> {
179  public:
180   // Attempts to extend the set by adding a new subgraph that gets created by
181   // adding one node to the parent subgraph. If such a subgraph already exists,
182   // returns nullptr, otherwise returns the pointer to the new subgraph.
183   Subgraph* ExtendParent(const Subgraph::Identity& parent_id, GenNode* node);
184 };
185 
186 }  // end namespace graph_analyzer
187 }  // end namespace grappler
188 }  // end namespace tensorflow
189 
190 #endif  // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_
191