1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
17 #define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
18 
#include <algorithm>
#include <functional>
#include <map>
#include <set>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/status.h"
30 
31 namespace tensorflow {
32 namespace graph_transforms {
33 
// Builds a lookup table so that nodes in the graph def can be found quickly
// from their name.
//   graph_def: the graph whose nodes are indexed.
//   result: output map from a name to a pointer to a const NodeDef.
// NOTE(review): the stored pointers presumably refer to nodes inside
// graph_def, in which case graph_def must outlive the map - confirm against
// the implementation.
void MapNamesToNodes(const GraphDef& graph_def,
                     std::map<string, const NodeDef*>* result);
37 
// For every node in the graph create a list of the nodes that use it as an
// input.
//   graph_def: the graph to analyze.
//   result: output map from a string key (presumably the producing node's
//     name - confirm) to the list of NodeDefs that consume it.
// NOTE(review): the stored pointers presumably refer to nodes inside
// graph_def, so graph_def would need to outlive the map - confirm.
void MapNodesToOutputs(const GraphDef& graph_def,
                       std::map<string, std::vector<const NodeDef*>>* result);
42 
// NodeDef input strings can contain other information besides the name of an
// input node. These include:
//  - Optional '^' prefix, indicating this is a control edge.
//  - The required name of the input node.
//  - Optional ':<number>' suffix, showing which output of the node to use.
// This function takes a raw string, and breaks it into those component parts.
// The rules for inputs in function libraries are a bit more complex, and
// aren't handled by this routine.
//   input_name: the raw input string taken from a NodeDef.
//   prefix, node_name, suffix: outputs receiving the three component parts.
// NOTE(review): whether a missing suffix is reported as empty or as a default
// port is not visible from this declaration - confirm in the implementation.
void NodeNamePartsFromInput(const string& input_name, string* prefix,
                            string* node_name, string* suffix);
53 
// Adds a ':0' port to any inputs with no suffix, to make comparisons easier.
// The canonical form is returned as a new string; input_name is taken by
// const reference and is not modified.
string CanonicalInputName(const string& input_name);
56 
// Convenience function to strip the optional prefix and suffix components from
// a string pulled from a NodeDef input, and return the plain node name.
// See NodeNamePartsFromInput for a description of those components.
string NodeNameFromInput(const string& input_name);
60 
// Returns a stable 64-bit hash for the contents of the NodeDef, so that
// equivalent nodes should have equal hashes.
// NOTE(review): which fields of the NodeDef contribute to the hash is decided
// by the implementation and is not visible here.
uint64 HashNodeDef(const NodeDef& node);
64 
// Adds the given node name to the end of the node's inputs.
//   input_name: the input string to append.
//   node: the NodeDef that is modified in place.
void AddNodeInput(const string& input_name, NodeDef* node);
67 
// Copies an attribute from one NodeDef to another. The attribute is read from
// `source` under `source_key` and stored in `dest` under `dest_key`, so the
// key may differ between the two nodes.
// NOTE(review): behavior when `source` has no attribute named `source_key` is
// not visible from this declaration - confirm in the implementation.
void CopyNodeAttr(const NodeDef& source, const string& source_key,
                  const string& dest_key, NodeDef* dest);
71 
72 // Inserts a value into a NodeDef's map of attributes.
73 // This is a bit different than AddNodeAttr in node_def_util.h because it
74 // overwrites any existing attributes with the same key.
75 template <class T>
SetNodeAttr(const string & key,const T & value,NodeDef * node)76 inline void SetNodeAttr(const string& key, const T& value, NodeDef* node) {
77   AttrValue attr_value;
78   SetAttrValue(value, &attr_value);
79   auto* attr_map = node->mutable_attr();
80   (*attr_map)[key] = attr_value;
81 }
82 
83 template <class T>
SetNodeTensorAttr(const string & key,const Tensor & tensor,NodeDef * node)84 inline void SetNodeTensorAttr(const string& key, const Tensor& tensor,
85                               NodeDef* node) {
86   TensorProto tensor_proto;
87   tensor.AsProtoTensorContent(&tensor_proto);
88   SetNodeAttr(key, tensor_proto, node);
89 }
90 
91 // Inserts a Tensor into the specified attribute of a NodeDef.
92 template <class T>
SetNodeTensorAttr(const string & key,const TensorShape & shape,const std::vector<T> & values,NodeDef * node)93 inline void SetNodeTensorAttr(const string& key, const TensorShape& shape,
94                               const std::vector<T>& values, NodeDef* node) {
95   const DataType dtype = DataTypeToEnum<T>::v();
96   CHECK_EQ(shape.num_elements(), values.size());
97   Tensor tensor(dtype, shape);
98   T* dest_data = tensor.flat<T>().data();
99   std::copy_n(values.data(), values.size(), dest_data);
100   SetNodeTensorAttr<T>(key, tensor, node);
101 }
102 
// Retrieves a tensor value from a NodeDef attribute.
//   node: the NodeDef to read from.
//   key: name of the attribute holding the tensor.
// The tensor is returned by value.
// NOTE(review): behavior when the attribute is missing or does not hold a
// tensor is not visible from this declaration - confirm in the implementation.
Tensor GetNodeTensorAttr(const NodeDef& node, const string& key);
105 
// Creates a copy of the input GraphDef, but only containing the nodes where the
// supplied selector function returned true.
//   input_graph_def: the graph to filter; it is not modified.
//   selector: predicate called with a NodeDef; returning true keeps the node.
//   output_graph_def: receives the filtered copy.
void FilterGraphDef(const GraphDef& input_graph_def,
                    std::function<bool(const NodeDef&)> selector,
                    GraphDef* output_graph_def);
111 
// Creates a copy of the input graph, with all occurrences of the attributes
// with the names in the argument removed from the node defs.
//   input_graph_def: the graph to copy; it is not modified.
//   attributes: names of the attributes to drop from every node.
//   output_graph_def: receives the stripped copy.
void RemoveAttributes(const GraphDef& input_graph_def,
                      const std::vector<string>& attributes,
                      GraphDef* output_graph_def);
117 
// For a lot of replacement and matching operations it's useful to have the
// nodes processed in a controlled order, so this does a topological sort to
// ensure that nodes always appear in the GraphDef.node list after their inputs.
//   input_graph_def: the graph to sort; it is not modified.
//   output_graph_def: receives the sorted graph.
// Returns a Status describing whether the sort succeeded.
// NOTE(review): the specific failure conditions (e.g. how cycles are treated)
// are not visible from this declaration - confirm in the implementation.
Status SortByExecutionOrder(const GraphDef& input_graph_def,
                            GraphDef* output_graph_def);
123 
// Finds inputs that refer to nodes that are not in the graph.
//   graph_def: the graph to inspect.
//   invalid_inputs: output list with one pair of strings per problem found.
// NOTE(review): the meaning of the two strings in each pair (presumably the
// referring node's name and the missing input's name) is not visible here -
// confirm in the implementation.
void FindInvalidInputs(const GraphDef& graph_def,
                       std::vector<std::pair<string, string>>* invalid_inputs);
127 
// Returns a descriptive error status if there are problems spotted with the
// graph.
// NOTE(review): the exact set of checks performed is defined by the
// implementation and is not visible from this declaration.
Status IsGraphValid(const GraphDef& graph_def);
131 
// Returns input and output types for a particular NodeDef.
//   node_def: the node to examine.
//   inputs: receives the data types of the node's inputs.
//   outputs: receives the data types of the node's outputs.
// The returned Status reports whether the types could be determined.
Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
                     DataTypeVector* outputs);
135 
// Takes a comma-separated string of numbers and parses them into a shape.
//   shape_string: the text to parse.
//   result: receives the parsed shape.
// The returned Status reports whether parsing succeeded.
Status TensorShapeFromString(const string& shape_string, TensorShape* result);
138 
// This is used to spot particular subgraphs in a larger model. To use it,
// create a pattern like:
// OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
// This defines a subgraph where a Conv2D has a ResizeBilinear input, which
// pulls from a MirrorPad op.
// Regular expressions aren't supported for the op names, but you can use "*" to
// match any op. You can also use | as a separator to match multiple op names,
// like "Reshape|Concat|Conv2D".
struct OpTypePattern {
  // Op name to match: a single name, several names separated by '|', or "*"
  // to match any op.
  string op;
  // Sub-patterns that the inputs of a matching node must satisfy in turn.
  std::vector<OpTypePattern> inputs;
  // Returns a string describing this pattern.
  string DebugString() const;
};
152 
// Holds a sub-graph of nodes that match a pattern: one node plus, nested
// inside it, the matches recorded for that node's inputs.
struct NodeMatch {
  // Starts out with a default-constructed NodeDef and no input matches.
  NodeMatch() : node() {}
  // The node at this position of the match, stored by value.
  NodeDef node;
  // Matches recorded for the inputs of `node`.
  // NOTE(review): presumably in the same order as the pattern's inputs -
  // confirm in the implementation.
  std::vector<NodeMatch> inputs;
  // Returns a string describing this match.
  string DebugString() const;
};
160 
161 // Utility class to spot subgraphs matching particular patterns.
162 class GraphMatcher {
163  public:
164   GraphMatcher(const GraphDef& graph_def);
165 
166   // Sorts the input nodes into execution order, and then skips any previously
167   // matches so that no node appears in more than one match. The NodeDef
168   // pointers contained in the results are owned by the GraphMatcher object, and
169   // so will be invalid after its lifetime.
170   Status GetOpTypeMatches(const OpTypePattern& pattern,
171                           std::vector<NodeMatch>* matches);
172 
173  private:
174   bool DoesOpTypeMatch(const NodeDef& node, const OpTypePattern& pattern,
175                        const std::set<string>& previously_matched_nodes,
176                        NodeMatch* match);
177 
178   GraphDef graph_def_;
179   std::map<string, const NodeDef*> node_map_;
180 };
181 
// Options controlling the integrity checks done by ReplaceMatchingOpTypes.
struct ReplaceMatchingOpTypesOptions {
  // Whether to tolerate a graph that is left with dangling inputs instead of
  // raising an error. If you enable this option, you must fix inconsistencies
  // in a later pass.
  // The field has no default member initializer, so a default-initialized
  // (as opposed to value-initialized, e.g. `{}`) instance leaves it unset.
  bool allow_inconsistencies;
};
187 
// Replaces all of the matching sub-graphs with new ops. This calls into the
// given function, and expects to receive a set of new nodes to replace each
// matched sub-graph. It has some logic to protect the integrity of the
// resulting graph, for example making sure that nodes needed by other nodes
// outside the sub-graph aren't removed. These are passed in as the set of
// outputs, and nodes with the same names must be added to the new nodes
// produced by the replacement function. Many of these checks can be disabled
// by setting allow_inconsistencies to true in the options, but then it's the
// caller's responsibility to patch up any problems before passing on the graph
// to others. There's more comprehensive usage documentation in the README.
//   input_graph_def: the graph to search; it is not modified.
//   pattern: the op-type pattern identifying the sub-graphs to replace.
//   node_generator: callback invoked with a match, two sets of strings and an
//     output vector that it fills with the replacement nodes; it returns a
//     Status. NOTE(review): the two sets are presumably the names of input
//     nodes and output nodes of the match, in that order - confirm.
//   options: see ReplaceMatchingOpTypesOptions.
//   output_graph_def: receives the rewritten graph.
Status ReplaceMatchingOpTypes(
    const GraphDef& input_graph_def, const OpTypePattern& pattern,
    const std::function<Status(const NodeMatch&, const std::set<string>&,
                               const std::set<string>&, std::vector<NodeDef>*)>&
        node_generator,
    const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def);
204 
// Returns a list of the unique nodes found in this match.
//   match: the match whose nodes are collected.
//   result: receives the NodeDefs, stored by value.
void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result);
207 
// Changes all input references to a particular node name. Any nodes with names
// listed in nodes_to_ignore will not have their inputs rewritten.
//   input_graph_def: the graph to process; it is not modified.
//   inputs_to_rename: map of string to string describing the renames
//     (presumably old name to new name - confirm in the implementation).
//   nodes_to_ignore: names of nodes whose inputs are left untouched.
//   output_graph_def: receives the rewritten graph.
Status RenameNodeInputs(const GraphDef& input_graph_def,
                        const std::map<string, string>& inputs_to_rename,
                        const std::unordered_set<string>& nodes_to_ignore,
                        GraphDef* output_graph_def);
214 
// Utility function that copies all the nodes found in a match into the
// new_nodes list. This is useful in replacement functions when you decide to
// leave the original matched subgraph untouched and make no changes.
//   match: the match whose nodes are copied.
//   new_nodes: the list that receives the copies.
void CopyOriginalMatch(const NodeMatch& match, std::vector<NodeDef>* new_nodes);
219 
// Holds information that's needed for transform functions.
// TransformFuncParameters maps a parameter name to the list of values given
// for it; a parameter can therefore occur more than once.
typedef std::map<string, std::vector<string>> TransformFuncParameters;
struct TransformFuncContext {
  // NOTE(review): presumably the names of the graph's input and output nodes
  // as supplied to the transform - confirm against the callers.
  std::vector<string> input_names;
  std::vector<string> output_names;
  // Parameters passed to the transform, keyed by parameter name.
  TransformFuncParameters params;

  // Returns how many occurrences of the given parameter are present.
  int CountParameters(const string& name) const;

  // Gets a single instance of a parameter, using a default if it's not present.
  // The value is written to `result`; the outcome is reported via the Status.
  Status GetOneStringParameter(const string& name, const string& default_value,
                               string* result) const;

  // Gets a single occurrence of a parameter as a 32-bit integer, falling back
  // to a default if it isn't present and returning an error if it isn't
  // convertible to a number.
  Status GetOneInt32Parameter(const string& name, int32 default_value,
                              int32* result) const;

  // Gets a single occurrence of a parameter as a 64-bit integer, falling back
  // to a default if it isn't present and returning an error if it isn't
  // convertible to a number.
  Status GetOneInt64Parameter(const string& name, int64 default_value,
                              int64* result) const;

  // Gets a single occurrence of a parameter as a floating point number, falling
  // back to a default if it isn't present and returning an error if it isn't
  // convertible to a number.
  Status GetOneFloatParameter(const string& name, float default_value,
                              float* result) const;

  // Gets a single occurrence of a parameter as a boolean, falling back to a
  // default if it isn't present and returning an error if it's not one of
  // "true", "1", "false", or "0".
  Status GetOneBoolParameter(const string& name, bool default_value,
                             bool* result) const;
};
258 
// This is the function API for all graph transformations, taking an input
// GraphDef and other arguments, and returning a transformed GraphDef.
// The arguments are, in order: the input graph, the TransformFuncContext with
// the transform's parameters, and the output graph to fill in. The returned
// Status reports whether the transform succeeded.
typedef std::function<Status(const GraphDef&,
                             const TransformFuncContext& context, GraphDef*)>
    TransformFunc;
264 
// To add a new graph transform function, call the macro:
// REGISTER_GRAPH_TRANSFORM("fold_constants", FoldConstants);
// Under the hood this adds the function to the list of known transforms, so you
// just need to link in the .cc file with your registration call to have access
// to it through the command line tool.
// The rest of the machinery below is to enable that automagical registration.
// TransformRegistry maps a transform's registered name to its function.
typedef std::map<string, TransformFunc> TransformRegistry;
// Returns a pointer to the registry that TransformRegistrar writes into.
// NOTE(review): presumably a single process-wide instance - confirm in the
// implementation.
TransformRegistry* GetTransformRegistry();
273 class TransformRegistrar {
274  public:
TransformRegistrar(const string & name,TransformFunc transform_func)275   TransformRegistrar(const string& name, TransformFunc transform_func) {
276     TransformRegistry* transform_registry = GetTransformRegistry();
277     (*transform_registry)[name] = transform_func;
278   }
279 };
// Public registration macro: defines a static TransformRegistrar object whose
// constructor registers `func` under `name`. __COUNTER__ provides a distinct
// value for each use so that the generated object names do not collide.
// Note that the expansion already ends with a semicolon.
#define REGISTER_GRAPH_TRANSFORM(name, func) \
  REGISTER_GRAPH_TRANSFORM_UNIQ_HELPER(__COUNTER__, name, func)
// Extra level of expansion so that __COUNTER__ is replaced by its value
// before it is pasted into an identifier with ## below.
#define REGISTER_GRAPH_TRANSFORM_UNIQ_HELPER(ctr, name, func) \
  REGISTER_GRAPH_TRANSFORM_UNIQ(ctr, name, func)
// Declares the registrar object itself, named registrar__body__<ctr>__object.
#define REGISTER_GRAPH_TRANSFORM_UNIQ(ctr, name, func)    \
  static tensorflow::graph_transforms::TransformRegistrar \
      registrar__body__##ctr##__object(name, func);
287 
288 }  // namespace graph_transforms
289 }  // namespace tensorflow
290 
291 #endif  // TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
292