1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_H_
17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_H_
18 
#include <functional>
#include <set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/utils/graph_view_internal.h"
#include "tensorflow/core/lib/core/status.h"
32 
33 namespace tensorflow {
34 namespace grappler {
35 namespace utils {
36 
// Forward declarations: the fanin/fanout helper classes below name the
// immutable node and graph views before those classes are defined.
class GraphView;
class NodeView;
40 
41 // FaninView is a helper class to represent fanouts of a node. This holds a
42 // pointer to GraphView, the index of the node being represented from GraphView,
43 // and the input index (hence is labeled as Fanin).
44 class FaninView : public internal::NodeIndexAndPortIndex<NodeView, GraphView> {
45  public:
FaninView()46   FaninView() : NodeIndexAndPortIndex() {}
47 
FaninView(GraphView * graph_view,int node_index,int port_index)48   FaninView(GraphView* graph_view, int node_index, int port_index)
49       : NodeIndexAndPortIndex(graph_view, node_index, port_index) {}
50 
51   FaninView(NodeView* node_view, int index);
52 
53  private:
54   friend class NodeView;
55   friend class GraphView;
56 };
57 
58 // FanoutView is a helper class to represent fanins of a node. This holds a
59 // pointer to GraphView, the index of the node being represented from GraphView,
60 // and the output index (hence is labeled as Fanout).
61 class FanoutView : public internal::NodeIndexAndPortIndex<NodeView, GraphView> {
62  public:
FanoutView()63   FanoutView() : NodeIndexAndPortIndex() {}
64 
FanoutView(GraphView * graph_view,int node_index,int port_index)65   FanoutView(GraphView* graph_view, int node_index, int port_index)
66       : NodeIndexAndPortIndex(graph_view, node_index, port_index) {}
67 
68   FanoutView(NodeView* node_view, int index);
69 
70  private:
71   friend class NodeView;
72   friend class GraphView;
73 };
74 
75 // Immutable NodeView that keeps the constness of the NodeDef. This allows for
76 // lookups of fanins and fanouts, and traversals of the graph, but no mutations.
77 // No dedupping of fanins will be performed on the node to preserve it's
78 // constness.
79 class NodeView : public internal::NodeViewInternal<FaninView, FanoutView,
80                                                    GraphView, true> {
81  public:
NodeView(GraphView * graph_view,int node_index)82   explicit NodeView(GraphView* graph_view, int node_index)
83       : NodeViewInternal(graph_view, node_index) {}
84 
NodeView()85   NodeView() : NodeViewInternal() {}
86 
87   ~NodeView() override = default;
88 
89   NodeView(NodeView&&) = default;
90   NodeView& operator=(NodeView&&) = default;
91 
92   const NodeDef* node() const override;
93 
94   // Checks if a fanin exists for the node.
95   bool HasFanin(const FanoutView& fanin) const override;
96 
97   // Checks if a fanout exists for the node.
98   bool HasFanout(const FaninView& fanout) const override;
99 
100  private:
101   inline const FanoutView& GetMissingFanin() const override;
102 
103   inline const std::vector<FaninView>& GetMissingFanout() const override;
104 
105   absl::flat_hash_set<internal::NodeDefAndPortIndex> fanins_set_;
106 
107   friend class FaninView;
108   friend class FanoutView;
109   friend class GraphView;
110 };
111 
112 // Immutable GraphView that keeps the constness of the GraphDef. This allows
113 // for lookups and traversals of the graph, but no mutations.
114 class GraphView : public internal::GraphViewInternal<NodeView, FaninView,
115                                                      FanoutView, true> {
116  public:
117   explicit GraphView(const GraphDef* graph, Status* status);
118   ~GraphView() override = default;
119 
120  private:
121   bool AddUniqueNodeInternal(const NodeDef* node);
122 
123   Status CheckAndAddFaninsInternal(NodeView* node_view);
124 
125   friend class NodeView;
126 };
127 
// Forward declarations: the mutable fanin/fanout helper classes below name
// these types before they are defined.
class Mutation;
class MutableGraphView;
class MutableNodeView;
133 
134 // MutableFaninView is a helper class to represent fanouts of a node. This holds
135 // a pointer to MutableGraphView, the index of the node from MutableGraphView
136 // being mutated, and the input index (hence is labeled as Fanin).
137 class MutableFaninView
138     : public internal::NodeIndexAndPortIndex<MutableNodeView,
139                                              MutableGraphView> {
140  public:
MutableFaninView()141   MutableFaninView() : NodeIndexAndPortIndex() {}
142 
MutableFaninView(MutableGraphView * graph_view,int node_index,int port_index)143   MutableFaninView(MutableGraphView* graph_view, int node_index, int port_index)
144       : NodeIndexAndPortIndex(graph_view, node_index, port_index) {}
145 
MutableFaninView(MutableGraphView * graph_view,int node_index,int port_index,int fanin_index)146   explicit MutableFaninView(MutableGraphView* graph_view, int node_index,
147                             int port_index, int fanin_index)
148       : NodeIndexAndPortIndex(graph_view, node_index, port_index),
149         fanin_index_(fanin_index) {
150     // TODO(lyandy): Remove once constructor is not public.
151     DCHECK(port_index < 0 || port_index == fanin_index);
152   }
153 
154   MutableFaninView(MutableNodeView* node_view, int index);
155 
156  private:
157   // Index of associated fanin in fanout's underlying MutableNodeView. For
158   // regular fanouts, this will be the same as port_index (index of the
159   // associated fanin in MutableNodeView::regular_fanins_). For controlled
160   // fanouts, this will be the index of the associated fanin in
161   // MutableNodeView::controlling_fanins_.
162   int fanin_index_ = internal::kMissingIndex;
163 
164   friend class MutableNodeView;
165   friend class MutableGraphView;
166   friend class Mutation;
167 };
168 
169 // MutableFanoutView is a helper class to represent fanins of a node. This holds
170 // a pointer to MutableGraphView, the index of the node from MutableGraphView
171 // being mutated, and the output index (hence is labeled as Fanout).
172 class MutableFanoutView
173     : public internal::NodeIndexAndPortIndex<MutableNodeView,
174                                              MutableGraphView> {
175  public:
MutableFanoutView()176   MutableFanoutView() : NodeIndexAndPortIndex() {}
177 
MutableFanoutView(MutableGraphView * graph_view,int node_index,int port_index)178   MutableFanoutView(MutableGraphView* graph_view, int node_index,
179                     int port_index)
180       : NodeIndexAndPortIndex(graph_view, node_index, port_index) {}
181 
MutableFanoutView(MutableGraphView * graph_view,int node_index,int port_index,int fanout_index)182   explicit MutableFanoutView(MutableGraphView* graph_view, int node_index,
183                              int port_index, int fanout_index)
184       : NodeIndexAndPortIndex(graph_view, node_index, port_index),
185         fanout_index_(fanout_index) {}
186 
187   MutableFanoutView(MutableNodeView* node_view, int index);
188 
189  private:
190   // Index of associated fanout in fanin's underlying MutableNodeView. For
191   // regular fanins, this will be the index of the associated fanout in
192   // MutableNodeView::regular_fanouts_by_port_[port_index]. For controlled
193   // fanins, this will be the index of the associated fanout in
194   // MutableNodeView::controlled_fanouts_.
195   int fanout_index_ = internal::kMissingIndex;
196 
197   friend class MutableNodeView;
198   friend class MutableGraphView;
199   friend class Mutation;
200 };
201 
202 // Mutable NodeView that holds a mutable NodeDef. This allows for lookups of
203 // fanins and fanouts, and traversals of the graph. Control dependencies will be
204 // dedupped among other control dependencies on initialization via
205 // MutableGraphView. Mutations should be handled via MutableGraphView and not
206 // directly on the mutable NodeDef.
207 class MutableNodeView
208     : public internal::NodeViewInternal<MutableFaninView, MutableFanoutView,
209                                         MutableGraphView, false> {
210  public:
MutableNodeView(MutableGraphView * graph_view,int node_index)211   explicit MutableNodeView(MutableGraphView* graph_view, int node_index)
212       : NodeViewInternal(graph_view, node_index) {}
213 
MutableNodeView()214   MutableNodeView() : NodeViewInternal() {}
215 
216   ~MutableNodeView() override = default;
217 
218   MutableNodeView(MutableNodeView&&) = default;
219   MutableNodeView& operator=(MutableNodeView&&) = default;
220 
221   NodeDef* node() const override;
222 
223   // Checks if a fanin exists for the node.
224   bool HasFanin(const MutableFanoutView& fanin) const override;
225 
226   // Checks if a fanout exists for the node.
227   bool HasFanout(const MutableFaninView& fanout) const override;
228 
229  private:
230   inline const MutableFanoutView& GetMissingFanin() const override;
231 
232   inline const std::vector<MutableFaninView>& GetMissingFanout() const override;
233 
234   absl::flat_hash_map<internal::NodeDefAndPortIndex, int> fanins_count_;
235   absl::flat_hash_map<absl::string_view, int> controlling_fanins_index_;
236   // Index of associated MutableNodeViewDiff in Mutation::updated_nodes_.
237   // If this is -1, there exists no MutableNodeViewDiff for this node.
238   int update_index_ = internal::kMissingIndex;
239 
240   friend class MutableFaninView;
241   friend class MutableFanoutView;
242   friend class MutableGraphView;
243   friend class Mutation;
244 };
245 
246 class MutationNewNode {
247  public:
MutationNewNode()248   MutationNewNode() {}
249 
250  private:
MutationNewNode(Mutation * mutation,int mutation_counter,int index)251   explicit MutationNewNode(Mutation* mutation, int mutation_counter, int index)
252       : mutation_(mutation),
253         mutation_counter_(mutation_counter),
254         index_(index) {}
255 
256   Mutation* mutation_ = nullptr;
257   int mutation_counter_ = internal::kMissingSlot;
258   int index_ = internal::kMissingIndex;
259 
260   friend class Mutation;
261 };
262 
263 // Mutation is a helper class that allows rewrites of MutableGraphView. This
264 // should not be initialized or be used directly.
265 // Note, if a node is renamed to another node, or a new node is created with the
266 // same name as an existing node, the node with the same name originally in the
267 // graph will be overwritten.
268 class Mutation {
269  public:
270   // Create a new node to be added to the graph. If the node's fanins are not
271   // well formed (self loops, control dependencies between regular fanins), the
272   // `status` will be set.
273   MutationNewNode AddNode(NodeDef&& node, Status* status);
274 
275   // Remove an existing node in the graph.
276   void RemoveNode(MutableNodeView* node);
277 
278   // Update the name of an existing node.
279   void UpdateNodeName(MutableNodeView* node, absl::string_view name);
280 
281   // Update the name of a new node.
282   void UpdateNodeName(const MutationNewNode& node, absl::string_view name);
283 
284   // Update the op of an existing node.
285   void UpdateNodeOp(MutableNodeView* node, absl::string_view op);
286 
287   // Update the op of a new node.
288   void UpdateNodeOp(const MutationNewNode& node, absl::string_view op);
289 
290   // Update the device of an existing node.
291   void UpdateNodeDevice(MutableNodeView* node, absl::string_view device);
292 
293   // Update the device of a new node.
294   void UpdateNodeDevice(const MutationNewNode& node, absl::string_view device);
295 
296   // Add or replace regular fanin `fanin` at `index` for an existing node.
297   void AddOrUpdateRegularFanin(MutableNodeView* node, int index,
298                                const TensorId& fanin);
299 
300   // Add or replace regular fanin `fanin` at `index` for a new node.
301   void AddOrUpdateRegularFanin(const MutationNewNode& node, int index,
302                                const TensorId& fanin);
303 
304   // Remove regular fanin at `index` for an existing node.
305   void RemoveRegularFanin(MutableNodeView* node, int index);
306 
307   // Remove regular fanin at `index` for a new node.
308   void RemoveRegularFanin(const MutationNewNode& node, int index);
309 
310   // Add controlling fanin `fanin_node_name` for an existing node.
311   void AddControllingFanin(MutableNodeView* node,
312                            absl::string_view fanin_node_name);
313 
314   // Add controlling fanin `fanin_node_name` for a new node.
315   void AddControllingFanin(const MutationNewNode& node,
316                            absl::string_view fanin_node_name);
317 
318   // Remove controlling fanin `fanin_node_name` for an existing node.
319   void RemoveControllingFanin(MutableNodeView* node,
320                               absl::string_view fanin_node_name);
321 
322   // Remove controlling fanin `fanin_node_name` for a new node.
323   void RemoveControllingFanin(const MutationNewNode& node,
324                               absl::string_view fanin_node_name);
325 
326   // Add or replace attribute `attr_name` with `attr_value` for an existing
327   // node.
328   void AddOrUpdateNodeAttr(MutableNodeView* node, absl::string_view attr_name,
329                            const AttrValue& attr_value);
330 
331   // Add or replace attribute `attr_name` with `attr_value` for a new node.
332   void AddOrUpdateNodeAttr(const MutationNewNode& node,
333                            absl::string_view attr_name,
334                            const AttrValue& attr_value);
335 
336   // Remove attribute `attr_name` for an existing node.
337   void RemoveNodeAttr(MutableNodeView* node, absl::string_view attr_name);
338 
339   // Remove attribute `attr_name` for a new node.
340   void RemoveNodeAttr(const MutationNewNode& node, absl::string_view attr_name);
341 
342   // Reset and clear mutation.
343   void Reset();
344 
345   // Applies the Mutation to the graph. If the mutation is valid, the graph will
346   // be modified. Otherwise an error status will be returned and the graph will
347   // not be modified.
348   Status Apply();
349 
350  private:
351   explicit Mutation(MutableGraphView* graph_view);
352 
353   void ResetInternal();
354 
355   using MutableNodeViewDiff = internal::NodeViewDiff<MutableGraphView>;
356 
357   // Adds a mutation to the `node`. Mutation function `mutate_fn` must return
358   // `true` if it actually does any mutations. If it returns `false` mutation
359   // will be ignored.
360   void AddMutation(MutableNodeView* node,
361                    std::function<bool(MutableNodeViewDiff*)> mutate_fn);
362 
363   MutableGraphView* graph_view_ = nullptr;
364   int mutation_counter_ = 0;
365   std::vector<MutableNodeViewDiff> updated_nodes_;
366   absl::flat_hash_set<int> removed_nodes_;
367 
368   using MutationNewNodeHolder = internal::NewNode<MutableGraphView>;
369   std::vector<MutationNewNodeHolder> new_nodes_;
370 
371   friend class MutableGraphView;
372 };
373 
374 // Mutable GraphView that holds a mutable GraphDef. This allows for lookups and
375 // traversals of the graph. Control dependencies will be dedupped among other
376 // control dependencies on initialization. Mutations should be handled using
377 // this API instead of directly on the GraphDef/NodeDef.
378 // Note, after a mutation, pointers of MutableNodeView's from MutableGraphView
379 // may be invalidated.
380 class MutableGraphView
381     : public internal::GraphViewInternal<MutableNodeView, MutableFaninView,
382                                          MutableFanoutView, false> {
383  public:
384   explicit MutableGraphView(GraphDef* graph, Status* status);
385   ~MutableGraphView() override = default;
386 
387   // Returns a Mutation (builder) that can be used to modify MutableGraphView.
388   Mutation* GetMutationBuilder();
389 
390   // Helper class representing an extra dependency for topological sorting.
391   class TopologicalDependency {
392    public:
TopologicalDependency(const MutableNodeView * from_node,const MutableNodeView * to_node)393     TopologicalDependency(const MutableNodeView* from_node,
394                           const MutableNodeView* to_node) {
395       if (from_node->graph_view_ == to_node->graph_view_) {
396         graph_view_ = from_node->graph_view_;
397         from_ = from_node->node_index_;
398         to_ = to_node->node_index_;
399       }
400     }
401 
402    private:
403     MutableGraphView* graph_view_ = nullptr;
404     int from_ = internal::kMissingIndex;
405     int to_ = internal::kMissingIndex;
406 
407     friend class MutableGraphView;
408   };
409 
410   // Sorts graph topologically in-place. If `ignore_cycles` is set, a
411   // topological like sorting will be performed when there are cycles. Otherwise
412   // if a cycle is detected or if the graph cannot be sorted, an error will be
413   // returned.
414   Status SortTopologically(
415       bool ignore_cycles,
416       absl::Span<const TopologicalDependency> extra_dependencies);
417 
418  private:
419   bool AddUniqueNodeInternal(NodeDef* node);
420 
421   Status CheckFaninsInternal(std::vector<std::vector<TensorId>>* fanins);
422 
423   void AddFaninsInternal(std::vector<std::vector<TensorId>>* fanins);
424 
425   // RenamedOrOverwrittenNode holds a index to Mutation::updated_nodes_ for a
426   // renamed node, alongside a potential overwritten node index in the actual
427   // graph. If the renamed node is not overwriting any existing nodes,
428   // `overwritten_node_index_` will be set to `internal::kMissingIndex`.
429   class RenamedOrOverwrittenNode {
430    public:
RenamedOrOverwrittenNode(int renamed_update_index,int overwritten_node_index)431     RenamedOrOverwrittenNode(int renamed_update_index,
432                              int overwritten_node_index)
433         : renamed_update_index_(renamed_update_index),
434           overwritten_node_index_(overwritten_node_index) {}
435 
436    private:
437     int renamed_update_index_;
438     int overwritten_node_index_;
439 
440     friend class MutableGraphView;
441   };
442 
443   Status GetNodeNamesAndPartitionUpdatedNodes(
444       absl::flat_hash_map<absl::string_view, int>* node_names,
445       std::vector<RenamedOrOverwrittenNode>* renamed_nodes,
446       std::vector<int>* inplace_nodes,
447       std::vector<int>* empty_diff_node_indices);
448 
449   Status RemovedOrMissingNodeFanoutsWellFormed(
450       const absl::flat_hash_map<absl::string_view, int>& node_names,
451       const std::vector<RenamedOrOverwrittenNode>& renamed_nodes);
452 
453   Status CheckNodeNamesAndFanins(
454       const absl::flat_hash_map<absl::string_view, int>& node_names,
455       const std::vector<RenamedOrOverwrittenNode>& renamed_nodes,
456       const std::vector<int>& inplace_nodes);
457 
458   Status CheckKernelRegisteredForNodes();
459 
460   // Helper class to move fanouts around.
461   class NodeViewFanouts {
462    public:
NodeViewFanouts(std::vector<std::vector<MutableFaninView>> && regular_fanouts_by_port,int num_regular_fanouts,std::vector<MutableFaninView> controlled_fanouts)463     NodeViewFanouts(
464         std::vector<std::vector<MutableFaninView>>&& regular_fanouts_by_port,
465         int num_regular_fanouts,
466         std::vector<MutableFaninView> controlled_fanouts)
467         : regular_fanouts_by_port_(std::move(regular_fanouts_by_port)),
468           num_regular_fanouts_(num_regular_fanouts),
469           controlled_fanouts_(std::move(controlled_fanouts)) {}
470 
471    private:
472     std::vector<std::vector<MutableFaninView>> regular_fanouts_by_port_;
473     int num_regular_fanouts_ = 0;
474     std::vector<MutableFaninView> controlled_fanouts_;
475 
476     friend class MutableGraphView;
477   };
478 
479   template <typename T>
480   void ReplaceNodeFanouts(MutableNodeView* node, T* fanouts);
481 
482   void FixRenamedNodes(
483       std::vector<RenamedOrOverwrittenNode>* renamed_nodes,
484       absl::flat_hash_map<string, NodeViewFanouts>* renamed_fanouts,
485       std::vector<bool>* overwritten_name_removed_nodes);
486 
487   void AddNewNodes(
488       absl::flat_hash_map<string, NodeViewFanouts>* renamed_fanouts,
489       std::vector<int>* new_node_indices);
490 
491   void FixRenamedFanouts(
492       const absl::flat_hash_map<string, NodeViewFanouts>& renamed_fanouts);
493 
494   inline void RemoveRegularFaninFanoutInternal(MutableNodeView* node_view,
495                                                int i);
496 
497   inline void AddRegularFaninInternal(MutableNodeView* node_view,
498                                       const SafeTensorId& fanin_id);
499 
500   inline void UpdateRegularFaninInternal(MutableNodeView* node_view,
501                                          const int i,
502                                          const SafeTensorId& fanin_id);
503 
504   inline void RemoveControllingFaninFanoutInternal(MutableNodeView* node_view,
505                                                    int i);
506 
507   inline void RemoveControllingFaninInternal(
508       MutableNodeView* node_view, const std::set<int>& indices_to_remove);
509 
510   inline void AddControllingFaninInternal(MutableNodeView* node_view,
511                                           absl::string_view fanin_node_name);
512 
513   void ApplyNodeUpdates();
514 
515   void SetNewNodesFanins(const std::vector<int>& new_node_indices);
516 
517   inline void RemoveAllFaninFanoutInternal(MutableNodeView* node_view);
518 
519   void RemoveNodesInternal(
520       const std::vector<RenamedOrOverwrittenNode>& renamed_nodes,
521       const std::vector<bool>& overwritten_name_removed_nodes);
522 
523   inline Status ValidateInternal(
524       absl::flat_hash_map<absl::string_view, int>* node_names,
525       std::vector<RenamedOrOverwrittenNode>* renamed_nodes,
526       std::vector<int>* inplace_nodes,
527       std::vector<int>* empty_diff_node_indices);
528 
529   Status ApplyMutationInternal();
530 
531   Mutation mutation_;
532 
533   friend class MutableNodeView;
534   friend class Mutation;
535 };
536 
537 }  // namespace utils
538 }  // namespace grappler
539 }  // namespace tensorflow
540 
541 #endif  // TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_H_
542