1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_
17 #define TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_
18 
19 #include <set>
20 #include <string>
21 
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/strings/string_view.h"
24 #include "absl/types/span.h"
25 #include "tensorflow/core/framework/graph.pb.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/graph/graph.h"
28 #include "tensorflow/core/graph/tensor_id.h"
29 #include "tensorflow/core/grappler/graph_view.h"
30 #include "tensorflow/core/grappler/op_types.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/platform/types.h"
33 
34 namespace tensorflow {
35 namespace grappler {
36 
// String constant exported alongside MutableGraphView.
// NOTE(review): presumably the tag used when naming the Identity nodes that
// GetOrCreateIdentityConsumingSwitch generates for Switch control dependencies
// -- this header does not show the use site, so confirm against the
// implementation file before relying on that.
constexpr char kMutableGraphViewCtrl[] = "ConstantFoldingCtrl";
38 
39 // A utility class to simplify the traversal of a GraphDef that, unlike
40 // GraphView, supports updating the graph.  Note that you should not modify the
41 // graph separately, because the view will get out of sync.
42 
43 class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
44  public:
MutableGraphView(GraphDef * graph)45   explicit MutableGraphView(GraphDef* graph) : GraphViewInternal(graph) {
46     for (NodeDef& node : *graph->mutable_node()) AddUniqueNodeOrDie(&node);
47     for (NodeDef& node : *graph->mutable_node()) AddAndDedupFanouts(&node);
48   }
49 
50   // Lookup fanouts/fanins using immutable ports.
51   using GraphViewInternal::GetFanout;
52   const absl::flat_hash_set<InputPort>& GetFanout(
53       const GraphView::OutputPort& port) const;
54 
55   using GraphViewInternal::GetFanin;
56   absl::flat_hash_set<OutputPort> GetFanin(
57       const GraphView::InputPort& port) const;
58 
59   using GraphViewInternal::GetRegularFanin;
60   const OutputPort GetRegularFanin(const GraphView::InputPort& port) const;
61 
62   // Adds a new node to graph and updates the view. Returns a pointer to the
63   // node in graph.
64   NodeDef* AddNode(NodeDef&& node);
65 
66   // Adds all nodes from the `subgraph` to the underlying graph and updates the
67   // view. `subgraph` doesn't have to be a valid graph definition on it's own,
68   // it can have edges to the nodes that are not in it, however after adding
69   // it to the underlying graph, final graph must be valid.
70   //
71   // If subgraph function library is not empty, all new functions will be added
72   // to the graph. Functions that appear with the same name in both subgraph and
73   // the graph represented by *this, must have identical function definitions.
74   //
75   // IMPORTANT: All nodes and functions of the given subgraph moved into the
76   // underlying graph, which leaves subgraph in valid but undefined state.
77   Status AddSubgraph(GraphDef&& subgraph);
78 
79   // Updates node `node_name` op, device, and attributes. This will clear any
80   // existing attributes. If it is not possible to update the node or if the
81   // node does not exist, an error will be returned and nothing will be modified
82   // in the graph.
83   Status UpdateNode(absl::string_view node_name, absl::string_view op,
84                     absl::string_view device,
85                     absl::Span<const std::pair<string, AttrValue>> attrs);
86 
87   // Updates node `from_node_name` name to `to_node_name`. If `to_node_name` is
88   // in use, node `from_node_name` does not exist, or node `from_node_name` has
89   // fanouts and `update_fanouts` is set to false, an error will be returned and
90   // nothing will be modified in the graph.
91   Status UpdateNodeName(absl::string_view from_node_name,
92                         absl::string_view to_node_name, bool update_fanouts);
93 
94   // Swap node names `from_node_name` and `to_node_name`. Self loops of one node
95   // are removed by updating the inputs introducing self loops to use the other
96   // node's name. Setting `update_fanouts` to false will exclude other fanouts
97   // from having their inputs updated, but inputs introducing self loops will
98   // always be updated regardless of `update_fanouts.
99   //
100   // Example:
101   //   1. foo(other:3, bar:2, ^bar)
102   //   2. bar(foo:3, other:1, foo:1, ^foo)
103   //   3. other(foo:5, bar:6)
104   //
105   // After calling SwapNodeNames("foo", "bar", false):
106   //   1. bar(other:3, foo:2, ^foo)
107   //   2. foo(bar:3, other:1, bar:1, ^bar)
108   //   3. other(foo:5, bar:6)
109   //
110   // After calling SwapNodeNames("foo", "bar", true):
111   //   1. bar(other:3, foo:2, ^foo)
112   //   2. foo(bar:3, other:1, bar:1, ^bar)
113   //   3. other(bar:5, foo:6)
114   //
115   // If it is not possible to swap node names (i.e. nodes do not exist or Switch
116   // control dependency may be introduced), an error will be returned and
117   // nothing will be modified in the graph.
118   Status SwapNodeNames(absl::string_view from_node_name,
119                        absl::string_view to_node_name, bool update_fanouts);
120 
121   // Updates all fanouts (input ports fetching output tensors) from
122   // `from_node_name` to the `to_node_name`, including control dependencies.
123   //
124   // Example: We have 3 nodes that use `bar` node output tensors as inputs:
125   //   1. foo1(bar:0, bar:1, other:0)
126   //   2. foo2(bar:1, other:1)
127   //   3. foo3(other:2, ^bar)
128   //
129   // After calling UpdateFanouts(bar, new_bar):
130   //   1. foo1(new_bar:0, new_bar:1, other:0)
131   //   2. foo2(new_bar:1, other:1)
132   //   3. foo3(other:2, ^new_bar)
133   Status UpdateFanouts(absl::string_view from_node_name,
134                        absl::string_view to_node_name);
135 
136   // Adds regular fanin `fanin` to node `node_name`. If the node or fanin do not
137   // exist in the graph, nothing will be modified in the graph. Otherwise fanin
138   // will be added after existing non control dependency fanins. Control
139   // dependencies will be deduped. To add control dependencies, use
140   // AddControllingFanin.
141   Status AddRegularFanin(absl::string_view node_name, const TensorId& fanin);
142 
143   // Adds regular fanin `fanin` to node `node_name` at port `port`. If the node
144   // or fanin do not exist in the graph, nothing will be modified in the graph.
145   // Otherwise fanin will be inserted at port `port`. Control dependencies will
146   // be deduped. To add control dependencies, use AddControllingFanin.
147   //
148   // If the port is not a valid port (less than 0 or greater than the number of
149   // regular fanins), this will result in an error and the node will not be
150   // modified.
151   Status AddRegularFaninByPort(absl::string_view node_name, int port,
152                                const TensorId& fanin);
153 
154   // Adds control dependency `fanin` to the target node named `node_name`. To
155   // add regular fanins, use AddRegularFanin.
156   //
157   // Case 1: If the fanin is not a Switch node, the control dependency is simply
158   // added to the target node:
159   //
160   //   fanin -^> target node.
161   //
162   // Case 2: If the fanin is a Switch node, we cannot anchor a control
163   // dependency on it, because unlike other nodes, only one of its outputs will
164   // be generated when the node is activated. In this case, we try to find an
165   // Identity/IdentityN node in the fanout of the relevant port of the Switch
166   // and add it as a fanin to the target node. If no such Identity/IdentityN
167   // node can be found, a new Identity node will be created. In both cases, we
168   // end up with:
169   //
170   //   fanin -> Identity{N} -^> target node.
171   //
172   // If the control dependency being added is redundant (control dependency
173   // already exists or control dependency can be deduped from regular fanins),
174   // this will not result in an error and the node will not be modified.
175   Status AddControllingFanin(absl::string_view node_name,
176                              const TensorId& fanin);
177 
178   // Removes regular fanin `fanin` from node `node_name`. If the node or fanin
179   // do not exist in the graph, nothing will be modified in the graph. If there
180   // are multiple inputs that match the fanin, all of them will be removed. To
181   // remove controlling fanins, use RemoveControllingFanin.
182   //
183   // If the fanin being removed doesn't exist in the node's inputs, this will
184   // not result in an error and the node will not be modified.
185   Status RemoveRegularFanin(absl::string_view node_name, const TensorId& fanin);
186 
187   // Removes regular fanin at port `port` from node `node_name`. If the node
188   // does not exist in the graph, nothing will be modified in the graph.
189   // To remove controlling fanins, use RemoveControllingFanin.
190   //
191   // If the port is not a valid port (less than 0 or greater than the last index
192   // of the regular fanins), this will result in an error and the node will not
193   // be modified.
194   Status RemoveRegularFaninByPort(absl::string_view node_name, int port);
195 
196   // Removes control dependency `fanin_node_name` from the target node named
197   // `node_name`. If the node or fanin do not exist in the graph, nothing will
198   // be modified in the graph. To remove regular fanins, use RemoveRegularFanin.
199   //
200   // If the fanin being removed doesn't exist in the node's inputs, this will
201   // not result in an error and the node will not be modified.
202   Status RemoveControllingFanin(absl::string_view node_name,
203                                 absl::string_view fanin_node_name);
204 
205   // Removes all fanins from node `node_name`. Control dependencies will be
206   // retained if keep_controlling_fanins is true.
207   //
208   // If no fanins are removed, this will not result in an error and the node
209   // will not be modified.
210   Status RemoveAllFanins(absl::string_view node_name,
211                          bool keep_controlling_fanins);
212 
213   // Replaces all fanins `from_fanin` with `to_fanin` in node `node_name`. If
214   // the fanins or node do not exist, nothing will be modified in the graph.
215   // Control dependencies will be deduped.
216   //
217   // If the fanin being updated doesn't exist in the node's inputs, this will
218   // not result in an error and the node will not be modified.
219   Status UpdateFanin(absl::string_view node_name, const TensorId& from_fanin,
220                      const TensorId& to_fanin);
221 
222   // Replaces fanin at port `port` in node `node_name` with fanin `fanin`. If
223   // the fanins or node do not exist, nothing will be modified in the graph.
224   // Control dependencies will be deduped.
225   //
226   // If the port is not a valid port (less than 0 or greater than the last index
227   // of the regular fanins), this will result in an error and the node will not
228   // be modified.
229   Status UpdateRegularFaninByPort(absl::string_view node_name, int port,
230                                   const TensorId& fanin);
231 
232   // Swaps fanins at ports `from_port` and `to_port` in node `node_name`. If the
233   // node does not exist, nothing will be modified in the graph.
234   //
235   // If the ports are not a valid port (less than 0 or greater than the last
236   // index of the regular fanins), this will result in an error and the node
237   // will not be modified.
238   Status SwapRegularFaninsByPorts(absl::string_view node_name, int from_port,
239                                   int to_port);
240 
241   // Updates all regular fanins to equivalent controlling fanins. If it is not
242   // possible, an error will be returned and nothing will be modified in the
243   // graph.
244   Status UpdateAllRegularFaninsToControlling(absl::string_view node_name);
245 
246   // Deletes nodes from the graph. If a node can't be safely removed,
247   // specifically if a node still has fanouts, an error will be returned. Nodes
248   // that can't be found are ignored.
249   Status DeleteNodes(const absl::flat_hash_set<string>& nodes_to_delete);
250 
251  private:
252   // Adds fanouts for fanins of node to graph, while deduping control
253   // dependencies from existing control dependencies and regular fanins. Note,
254   // node inputs will be mutated if control dependencies can be deduped.
255   void AddAndDedupFanouts(NodeDef* node);
256 
257   // Finds next output port smaller than fanin.port_id and update. The
258   // max_regular_output_port is only updated if fanin.port_id is the same as the
259   // current max_regular_output_port and if the fanouts set is empty. If there
260   // are no regular outputs, max_regular_output_port will be erased.
261   void UpdateMaxRegularOutputPortForRemovedFanin(
262       const OutputPort& fanin,
263       const absl::flat_hash_set<InputPort>& fanin_fanouts);
264 
265   // Updates max regular output port for newly added fanin by checking the
266   // current max and updating if the newly added fanin is of a larger port.
267   void UpdateMaxRegularOutputPortForAddedFanin(const OutputPort& fanin);
268 
269   // Updates all fanouts (input ports fetching output tensors) from `from_node`
270   // to the `to_node`, including control dependencies.
271   //
272   // Example: We have 3 nodes that use `bar` node output tensors as inputs:
273   //   1. foo1(bar:0, bar:1, other:0)
274   //   2. foo2(bar:1, other:1)
275   //   3. foo3(other:2, ^bar)
276   //
277   // After calling UpdateFanouts(bar, new_bar):
278   //   1. foo1(new_bar:0, new_bar:1, other:0)
279   //   2. foo2(new_bar:1, other:1)
280   //   3. foo3(other:2, ^new_bar)
281   //
282   // IMPORTANT: If `from_node` or `to_node` is not in the underlying graph, the
283   // behavior is undefined.
284   Status UpdateFanoutsInternal(NodeDef* from_node, NodeDef* to_node);
285 
286   // Adds fanin to node. If fanin is a control dependency, existing control
287   // dependencies will be checked first before adding. Otherwise fanin will be
288   // added after existing non control dependency inputs.
289   bool AddFaninInternal(NodeDef* node, const OutputPort& fanin);
290 
291   // Finds control dependency node to be used based on fanin. If fanin is not a
292   // Switch node, fanin.node is simply returned. Otherwise this will try to find
293   // a candidate Identity node consuming fanin, as the control dependency. If it
294   // is not possible or will introduce a self loop, an error message will be
295   // set. If nullptr is returned with no error
296   // GetOrCreateIdentityConsumingSwitch should be called to generate the new
297   // Identity node.
298   NodeDef* GetControllingFaninToAdd(absl::string_view node_name,
299                                     const OutputPort& fanin, string* error_msg);
300 
301   // Finds a generated Identity node consuming Switch node `fanin.node` at port
302   // `fanin.port_id`. If such a node does not exist, a new Identity node will be
303   // created.
304   NodeDef* GetOrCreateIdentityConsumingSwitch(const OutputPort& fanin);
305 
306   // Removes all instances of regular fanin `fanin` from node `node`.
307   bool RemoveRegularFaninInternal(NodeDef* node, const OutputPort& fanin);
308 
309   // Removes controlling fanin `fanin_node` from node if such controlling fanin
310   // exists.
311   bool RemoveControllingFaninInternal(NodeDef* node, NodeDef* fanin_node);
312 
313   // Checks if nodes to be deleted are missing or have any fanouts that will
314   // remain in the graph. If node is removed in either case, the graph will
315   // enter an invalid state.
316   Status CheckNodesCanBeDeleted(
317       const absl::flat_hash_set<string>& nodes_to_delete);
318 
319   // Removes fanins of the deleted node from internal state. Control
320   // dependencies are retained iff keep_controlling_fanins is true.
321   void RemoveFaninsInternal(NodeDef* deleted_node,
322                             bool keep_controlling_fanins);
323 
324   // Removes fanouts of the deleted node from internal state.
325   void RemoveFanoutsInternal(NodeDef* deleted_node);
326 };
327 
328 }  // end namespace grappler
329 }  // end namespace tensorflow
330 
331 #endif  // TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_
332