1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_INTERNAL_H_
17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_INTERNAL_H_
18 
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/hash/hash.h"
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/graph.pb.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/framework/node_def_util.h"
27 #include "tensorflow/core/graph/tensor_id.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/lib/gtl/map_util.h"
30 
31 namespace tensorflow {
32 namespace grappler {
33 namespace utils {
34 namespace internal {
35 
// Sentinel port index for a fanin/fanout that does not exist (also the port of
// the placeholder returned by EmptyTensorId()).
constexpr int kMissingSlot{-2};
// Sentinel node/entry index: used by default-constructed views, and passed as
// `control_index` when a controlling fanin is not present in the original node.
constexpr int kMissingIndex{-1};
// Value in an updated-node-name map meaning "a node with this name exists".
constexpr int kNodeNamePresent{-1};
39 
40 // NodeIndexAndPortIndex is a helper class that represents fanins and fanouts
41 // of a node.
42 template <typename NodeViewT, typename GraphViewT>
43 class NodeIndexAndPortIndex {
44  public:
NodeIndexAndPortIndex()45   NodeIndexAndPortIndex()
46       : graph_view_(nullptr),
47         node_index_(kMissingIndex),
48         port_index_(kMissingSlot) {}
NodeIndexAndPortIndex(GraphViewT * graph_view,int node_index,int port_index)49   NodeIndexAndPortIndex(GraphViewT* graph_view, int node_index, int port_index)
50       : graph_view_(graph_view),
51         node_index_(node_index),
52         port_index_(port_index) {}
53 
54   bool operator==(const NodeIndexAndPortIndex& other) const {
55     return port_index_ == other.port_index_ &&
56            node_index_ == other.node_index_ && graph_view_ == other.graph_view_;
57   }
58 
59   template <typename Hash>
AbslHashValue(Hash h,const NodeIndexAndPortIndex & n)60   friend Hash AbslHashValue(Hash h, const NodeIndexAndPortIndex& n) {
61     return Hash::combine(std::move(h), n.node_index_, n.port_index_);
62   }
63 
64   // Returns NodeView from `graph_view_` at `node_index_`.
node_view()65   NodeViewT* node_view() const {
66     if (graph_view_ == nullptr) {
67       return nullptr;
68     }
69     return graph_view_->GetNode(node_index_);
70   }
71 
72   // Returns node index in graph.
node_index()73   int node_index() const { return node_index_; }
74 
75   // Returns input/output port index.
index()76   int index() const { return port_index_; }
77 
78  protected:
79   GraphViewT* graph_view_;
80   int node_index_;
81   int port_index_;
82 };
83 
84 // NodeDefAndPortIndex is a helper class that represents fanins hashed with
85 // pointer stability using the fanin's NodeDef.
86 class NodeDefAndPortIndex {
87  public:
NodeDefAndPortIndex(const NodeDef * node_def,int port_index)88   NodeDefAndPortIndex(const NodeDef* node_def, int port_index)
89       : node_def_(node_def), port_index_(port_index) {}
90 
91   bool operator==(const NodeDefAndPortIndex& other) const {
92     return node_def_ == other.node_def_ && port_index_ == other.port_index_;
93   }
94 
95   template <typename Hash>
AbslHashValue(Hash h,const NodeDefAndPortIndex & n)96   friend Hash AbslHashValue(Hash h, const NodeDefAndPortIndex& n) {
97     return Hash::combine(std::move(h), n.node_def_, n.port_index_);
98   }
99 
100  private:
101   const NodeDef* node_def_;
102   int port_index_;
103 };
104 
105 // NodeViewInternal is a helper class to simplify graph traversal. It creates
106 // a view of a node and associated fanins and fanouts from the NodeDef
107 // protocol buffer.
108 //
109 // There are two public classes implementing NodeViewInternal:
110 //
111 // - NodeView: constructed from `const NodeDef` and doesn't allow mutating the
112 //   underlying node.
113 // - MutableNodeView: constructed from `NodeDef` and allows mutating the
114 //   underlying node.
115 //
116 // --------------------------- !!! WARNING !!! ---------------------------------
117 //     Modifying the node outside of implementations of NodeViewInternal
118 //     (i.e. modifying inputs of the NodeDef directly) may leave the NodeView
119 //     in an inconsistent/invalid state.
120 // -----------------------------------------------------------------------------
121 //
122 template <typename FaninViewT, typename FanoutViewT, typename GraphViewT,
123           bool IsConst>
124 class NodeViewInternal {
125  private:
126   using NodeDefT =
127       typename std::conditional<IsConst, const NodeDef, NodeDef>::type;
128 
129  public:
NodeViewInternal(GraphViewT * graph_view,int node_index)130   explicit NodeViewInternal(GraphViewT* graph_view, int node_index)
131       : graph_view_(graph_view),
132         node_index_(node_index),
133         attrs_(AttrSlice(graph_view->graph()->node(node_index))) {}
134 
NodeViewInternal()135   NodeViewInternal()
136       : graph_view_(nullptr), node_index_(kMissingIndex), attrs_(AttrSlice()) {}
137 
~NodeViewInternal()138   virtual ~NodeViewInternal() {}
139 
140   NodeViewInternal(NodeViewInternal&&) = default;
141   NodeViewInternal& operator=(NodeViewInternal&&) = default;
142 
143   bool operator==(const NodeViewInternal& other) const {
144     return node_index_ == other.node_index_ && graph_view_ == other.graph_view_;
145   }
146 
147   template <typename Hash>
AbslHashValue(Hash h,const NodeViewInternal & n)148   friend Hash AbslHashValue(Hash h, const NodeViewInternal& n) {
149     return Hash::combine(std::move(h), n.node_index_);
150   }
151 
152   // Returns NodeDef of view.
153   virtual NodeDefT* node() const = 0;
154 
155   // Returns index of node in GraphDef/GraphView.
node_index()156   int node_index() const { return node_index_; }
157 
158   // Returns the name of the node.
GetName()159   const string& GetName() const { return node()->name(); }
160 
161   // Returns the op of the node.
GetOp()162   const string& GetOp() const { return node()->op(); }
163 
164   // Returns the device set for the node.
GetDevice()165   const string& GetDevice() const { return node()->device(); }
166 
167   // Returns all regular fanins, based on ordering in the node.
GetRegularFanins()168   const std::vector<FanoutViewT>& GetRegularFanins() const {
169     return regular_fanins_;
170   }
171 
172   // Returns a regular fanin based on input index. If no such fanin exist, a
173   // missing fanin is returned, with no NodeView set and an index of -2.
GetRegularFanin(int i)174   const FanoutViewT& GetRegularFanin(int i) const {
175     int regular_fanins_size = regular_fanins_.size();
176     if (i < 0 || i >= regular_fanins_size) {
177       return GetMissingFanin();
178     }
179     return regular_fanins_[i];
180   }
181 
182   // Returns all controlling fanins, based on ordering in the node.
GetControllingFanins()183   const std::vector<FanoutViewT>& GetControllingFanins() const {
184     return controlling_fanins_;
185   }
186 
187   // Returns all regular fanouts.
GetRegularFanouts()188   const std::vector<std::vector<FaninViewT>>& GetRegularFanouts() const {
189     return regular_fanouts_by_port_;
190   }
191 
192   // Returns a regular fanout(s) based on output index. If no such output index
193   // exists, no fanouts will be returned.
GetRegularFanout(int i)194   const std::vector<FaninViewT>& GetRegularFanout(int i) const {
195     int regular_fanouts_by_port_size = regular_fanouts_by_port_.size();
196     if (i < 0 || i >= regular_fanouts_by_port_size) {
197       return GetMissingFanout();
198     }
199     return regular_fanouts_by_port_[i];
200   }
201 
202   // Returns all controlled fanouts.
GetControlledFanouts()203   const std::vector<FaninViewT>& GetControlledFanouts() const {
204     return controlled_fanouts_;
205   }
206 
207   // Returns the number of regular fanins.
NumRegularFanins()208   int NumRegularFanins() const { return regular_fanins_.size(); }
209 
210   // Returns the number of controlling fanins.
NumControllingFanins()211   int NumControllingFanins() const { return controlling_fanins_.size(); }
212 
213   // Returns the number of regular fanouts.
NumRegularFanouts()214   int NumRegularFanouts() const { return num_regular_fanouts_; }
215 
216   // Returns the number of controlled fanouts.
NumControlledFanouts()217   int NumControlledFanouts() const { return controlled_fanouts_.size(); }
218 
219   // Checks if a fanin exists for the node.
220   virtual bool HasFanin(const FanoutViewT& fanin) const = 0;
221 
222   // Checks if a fanout exists for the node.
223   virtual bool HasFanout(const FaninViewT& fanout) const = 0;
224 
225   // Returns an attribute of the node by key. If no attribute for such key
226   // exists, a `nullptr` is returned.
GetAttr(absl::string_view attr_name)227   const AttrValue* GetAttr(absl::string_view attr_name) const {
228     return attrs_.Find(attr_name);
229   }
230 
231   // Returns all attributes of the node.
GetAttrs()232   const AttrSlice& GetAttrs() const { return attrs_; }
233 
234   // Returns the number of attributes in the node.
NumAttrs()235   int NumAttrs() const { return attrs_.size(); }
236 
237   // Checks if an attribute exist in the node.
HasAttr(absl::string_view attr_name)238   bool HasAttr(absl::string_view attr_name) const {
239     return attrs_.Find(attr_name) != nullptr;
240   }
241 
242  protected:
243   virtual inline const FanoutViewT& GetMissingFanin() const = 0;
244   virtual inline const std::vector<FaninViewT>& GetMissingFanout() const = 0;
245 
246   std::vector<FanoutViewT> regular_fanins_;
247   std::vector<FanoutViewT> controlling_fanins_;
248   std::vector<std::vector<FaninViewT>> regular_fanouts_by_port_;
249   int num_regular_fanouts_ = 0;
250   std::vector<FaninViewT> controlled_fanouts_;
251 
252   GraphViewT* graph_view_;
253   int node_index_;
254   AttrSlice attrs_;
255 };
256 
257 // GraphViewInternal is a helper class to simplify graph traversal. It creates
258 // a view of the nodes and associated fanins and fanouts from the GraphDef
259 // protocol buffer.
260 //
261 // There are two public classes implementing GraphViewInternal:
262 //
263 // - GraphView: constructed from `const GraphDef` and doesn't allow mutating
264 //   the underlying graph and its nodes.
265 // - MutableGraphView: constructed from `GraphDef` and allows mutating the
266 //   underlying graph and its nodes.
267 //
268 // --------------------------- !!! WARNING !!! ---------------------------------
269 //     Modifying the graph outside of implementations of GraphViewInternal
270 //     (i.e. removing nodes from the GraphDef directly) may lead to
271 //     segfaults! Guaranteed by absl::string_view!
272 // -----------------------------------------------------------------------------
273 //
274 template <typename NodeViewT, typename FaninViewT, typename FanoutViewT,
275           bool IsConst>
276 class GraphViewInternal {
277  private:
278   using GraphDefT =
279       typename std::conditional<IsConst, const GraphDef, GraphDef>::type;
280 
281  public:
GraphViewInternal(GraphDefT * graph)282   explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {}
~GraphViewInternal()283   virtual ~GraphViewInternal() {}
284 
285   bool operator==(const GraphViewInternal& other) const {
286     return graph_ == other.graph_;
287   }
288 
graph()289   GraphDefT* graph() const { return graph_; }
290 
291   // Finds node by index in the graph. If no such node exists in the graph, a
292   // `nullptr` is returned.
GetNode(int node_index)293   const NodeViewT* GetNode(int node_index) const {
294     int nodes_size = nodes_.size();
295     if (node_index < 0 || node_index >= nodes_size) {
296       return nullptr;
297     }
298     return &nodes_[node_index];
299   }
300 
GetNode(int node_index)301   NodeViewT* GetNode(int node_index) {
302     int nodes_size = nodes_.size();
303     if (node_index < 0 || node_index >= nodes_size) {
304       return nullptr;
305     }
306     return &nodes_[node_index];
307   }
308 
309   // Finds node by name. If no such node exists in the graph, a `nullptr` is
310   // returned.
GetNode(absl::string_view node_name)311   const NodeViewT* GetNode(absl::string_view node_name) const {
312     auto it = node_index_by_name_.find(node_name);
313     if (it == node_index_by_name_.end()) {
314       return nullptr;
315     }
316     return &nodes_[it->second];
317   }
318 
GetNode(absl::string_view node_name)319   NodeViewT* GetNode(absl::string_view node_name) {
320     auto it = node_index_by_name_.find(node_name);
321     if (it == node_index_by_name_.end()) {
322       return nullptr;
323     }
324     return &nodes_[it->second];
325   }
326 
327   // Returns all nodes (as NodeView) in the graph.
GetNodes()328   const std::vector<NodeViewT>& GetNodes() const { return nodes_; }
329 
330   // Checks if a node by name exists in the graph.
HasNode(absl::string_view node_name)331   bool HasNode(absl::string_view node_name) const {
332     return node_index_by_name_.contains(node_name);
333   }
334 
335   // Returns the number of nodes in the graph.
NumNodes()336   int NumNodes() const { return nodes_.size(); }
337 
338  protected:
339   // Reset allocated node vector and node map in case of failure.
Reset()340   void Reset() {
341     std::vector<NodeViewT>().swap(nodes_);
342     absl::flat_hash_map<absl::string_view, int>().swap(node_index_by_name_);
343   }
344 
345   // nodes_[i] is a view of graph_.{mutable_}node(i).
346   std::vector<NodeViewT> nodes_;
347   absl::flat_hash_map<absl::string_view, int> node_index_by_name_;
348   GraphDefT* graph_;
349   const FanoutViewT missing_fanin_;
350   const std::vector<FaninViewT> missing_fanout_;
351 };
352 
EmptyTensorId()353 inline SafeTensorId EmptyTensorId() {
354   return SafeTensorId("", internal::kMissingSlot);
355 }
356 
IsEmptyTensorId(const TensorId tensor_id)357 inline bool IsEmptyTensorId(const TensorId tensor_id) {
358   return tensor_id.node().empty() &&
359          tensor_id.index() == internal::kMissingSlot;
360 }
361 
362 // NodeViewDiff is a helper struct holding changes to be made to an existing
363 // node in GraphViewT. This should not be initialized or be used directly.
364 template <typename GraphViewT>
365 struct NodeViewDiff {
NodeViewDiffNodeViewDiff366   explicit NodeViewDiff(GraphViewT* graph_view, int node_index)
367       : graph_view(graph_view), node_index(node_index) {}
368 
369   GraphViewT* graph_view;
370   int node_index;
371   string name;
372   bool update_name = false;
373   string op;
374   bool update_op = false;
375   string device;
376   bool update_device = false;
377   // Fanins to append after existing regular fanins.
378   std::vector<SafeTensorId> regular_inputs_to_add;
379   // Number of fanins to be appended. This is used for a quick comparison with
380   // `regular_inputs_to_add` for if there will be any missing inputs in the
381   // updated node.
382   int num_regular_inputs_to_add = 0;
383   // Fanins to update inplace.
384   std::map<int, SafeTensorId> regular_inputs_to_update;
385   // Fanins from end of regular fanins to remove. This keeps track of existing
386   // regular fanins in the original node to remove.
387   std::vector<bool> regular_inputs_to_remove;
388   // Number of fanins marked for removal. This is used for a quick comparison
389   // with `regular_inputs_to_remove` for if there will be any missing inputs
390   // in the updated node.
391   int num_regular_inputs_to_remove = 0;
392   absl::flat_hash_set<string> controlling_inputs_to_add;
393   std::set<int> controlling_inputs_to_remove;
394   absl::flat_hash_map<string, AttrValue> attrs_to_add;
395   absl::flat_hash_set<string> attrs_to_remove;
396   // AttrValueMap constructor and destructor are very expensive, we will
397   // initialize it lazily only if needed.
398   absl::optional<AttrValueMap> processed_attrs;
399 };
400 
401 // Updates node name. If `name` is the same as the name in the original node,
402 // the field will be cleared in the diff.
403 template <typename GraphViewT>
UpdateName(NodeViewDiff<GraphViewT> * diff,absl::string_view name)404 inline bool UpdateName(NodeViewDiff<GraphViewT>* diff, absl::string_view name) {
405   if (diff->graph_view->GetNode(diff->node_index)->GetName() == name) {
406     diff->name.clear();
407     diff->update_name = false;
408   } else {
409     diff->name = string(name);
410     diff->update_name = true;
411   }
412   return true;
413 }
414 
415 // Updates node op. If `op` is the same as the op in the original node, the
416 // field will be cleared in the diff.
417 template <typename GraphViewT>
UpdateOp(NodeViewDiff<GraphViewT> * diff,absl::string_view op)418 inline bool UpdateOp(NodeViewDiff<GraphViewT>* diff, absl::string_view op) {
419   if (diff->graph_view->GetNode(diff->node_index)->GetOp() == op) {
420     diff->op.clear();
421     diff->update_op = false;
422   } else {
423     diff->op = string(op);
424     diff->update_op = true;
425   }
426   return true;
427 }
428 
429 // Updates node device. If `device` is the same as the device in the original
430 // node, the field will be cleared in the diff.
431 template <typename GraphViewT>
UpdateDevice(NodeViewDiff<GraphViewT> * diff,absl::string_view device)432 inline bool UpdateDevice(NodeViewDiff<GraphViewT>* diff,
433                          absl::string_view device) {
434   if (diff->graph_view->GetNode(diff->node_index)->GetDevice() == device) {
435     diff->device.clear();
436     diff->update_device = false;
437   } else {
438     diff->device = string(device);
439     diff->update_device = true;
440   }
441   return true;
442 }
443 
// Stores `value` at index `i` of `*v`.
// - If `i` is inside the vector, the element is overwritten. Returns true only
//   if that element previously equalled `default_value`.
// - If `i` is at or past the end, the vector grows. Any gap before `i` is
//   filled with `default_value` so later validity checks can spot holes.
//   Returns true.
// `i` must be non-negative.
template <typename T, typename U>
inline bool AddOrUpdateAtIndex(std::vector<T>* v, int i, const U& value,
                               const T& default_value) {
  const int current_size = v->size();
  if (i < current_size) {
    const bool was_default = (*v)[i] == default_value;
    (*v)[i] = {value};
    return was_default;
  }
  if (i > current_size) {
    // Pad [current_size, i) with placeholders before appending.
    v->reserve(i + 1);
    v->resize(i, default_value);
  }
  v->push_back({value});
  return true;
}
469 
470 // Checks if a node with name `node_name` will exist in the final mutated graph.
471 template <typename GraphViewT>
CheckNodeNameExists(absl::string_view node_name,const absl::flat_hash_map<absl::string_view,int> & updated_node_names,const GraphViewT * graph_view)472 inline bool CheckNodeNameExists(
473     absl::string_view node_name,
474     const absl::flat_hash_map<absl::string_view, int>& updated_node_names,
475     const GraphViewT* graph_view) {
476   auto it = updated_node_names.find(node_name);
477   if (it != updated_node_names.end()) {
478     return it->second == kNodeNamePresent;
479   }
480   return graph_view->HasNode(node_name);
481 }
482 
483 // Adds or updates regular fanin at `index` of regular fanins. If `index` is
484 // less than the number of regular fanins in the original node, the fanin at
485 // `index` in the original node will be updated with `fanin` if the fanin
486 // differs. If `index` is greater than or equal to the number of regular fanins,
487 // `fanin` will be added beyond the end of regular fanins at `index`.
488 template <typename GraphViewT>
AddOrUpdateRegularFanin(NodeViewDiff<GraphViewT> * diff,int index,const TensorId & fanin)489 inline bool AddOrUpdateRegularFanin(NodeViewDiff<GraphViewT>* diff, int index,
490                                     const TensorId& fanin) {
491   if (index < 0) {
492     // Not a valid index for regular fanins.
493     return false;
494   }
495   auto* node_view = diff->graph_view->GetNode(diff->node_index);
496   const int num_regular_fanins = node_view->NumRegularFanins();
497   if (index < num_regular_fanins) {  // Updating existing fanins.
498     // Calculate (relative) index from end of regular fanins, from absolute
499     // index from beginning of regular fanins.
500     const int relative_removal_index = num_regular_fanins - index - 1;
501     // Check if at relative index fanin was already marked for removal.
502     int diff_regular_inputs_to_remove_size =
503         diff->regular_inputs_to_remove.size();
504     if (relative_removal_index < diff_regular_inputs_to_remove_size &&
505         diff->regular_inputs_to_remove[relative_removal_index]) {
506       // Unmark fanin for removal.
507       diff->regular_inputs_to_remove[relative_removal_index] = false;
508       --diff->num_regular_inputs_to_remove;
509     }
510     const auto& existing_fanin = node_view->GetRegularFanin(index);
511     if (existing_fanin.index() != fanin.index() ||
512         existing_fanin.node_view()->GetName() != fanin.node()) {
513       // Update fanin if it is different from original fanin in node.
514       gtl::InsertOrUpdate(&diff->regular_inputs_to_update, index,
515                           SafeTensorId(fanin));
516     }
517   } else {
518     // Add fanin beyond current fanin range.
519     const int relative_add_index = index - num_regular_fanins;
520     if (AddOrUpdateAtIndex(&diff->regular_inputs_to_add, relative_add_index,
521                            fanin, EmptyTensorId())) {
522       // New fanin was added.
523       ++diff->num_regular_inputs_to_add;
524     }
525   }
526   return true;
527 }
528 
529 // Remove regular fanin at `index` of regular fanins. This can remove existing
530 // fanins and updated/added fanins via AddOrUpdateRegularFanins.
531 template <typename GraphViewT>
RemoveRegularFanin(NodeViewDiff<GraphViewT> * diff,int index)532 inline bool RemoveRegularFanin(NodeViewDiff<GraphViewT>* diff, int index) {
533   if (index < 0) {
534     // Not a valid index for regular fanins.
535     return false;
536   }
537   auto* node_view = diff->graph_view->GetNode(diff->node_index);
538   const int num_regular_fanins = node_view->NumRegularFanins();
539   if (index < num_regular_fanins) {  // Removing existing fanins.
540     // Remove updated fanin if it exists.
541     diff->regular_inputs_to_update.erase(index);
542     // Calculate (relative) index from end of regular fanins, from absolute
543     // index from beginning of regular fanins.
544     const int relative_removal_index = num_regular_fanins - index - 1;
545     if (AddOrUpdateAtIndex(&diff->regular_inputs_to_remove,
546                            relative_removal_index,
547                            /*value=*/true, /*default_value=*/false)) {
548       ++diff->num_regular_inputs_to_remove;
549     }
550   } else {
551     // Relative index from end of regular fanins.
552     const int relative_add_index = index - num_regular_fanins;
553     int diff_regular_inputs_to_add_size = diff->regular_inputs_to_add.size();
554     if (relative_add_index >= diff_regular_inputs_to_add_size ||
555         IsEmptyTensorId(diff->regular_inputs_to_add[relative_add_index])) {
556       // At relative index, appended regular fanin was already marked for
557       // removal.
558       return false;
559     }
560     // Remove added fanin.
561     diff->regular_inputs_to_add[relative_add_index] = EmptyTensorId();
562     --diff->num_regular_inputs_to_add;
563   }
564   return true;
565 }
566 
567 // Adds controlling fanin. If the controlling fanin already exists in the
568 // original node, it will be dedupped. If the controlling fanin is marked for
569 // removal, this will reverse it.
570 template <typename GraphViewT>
AddControllingFanin(NodeViewDiff<GraphViewT> * diff,int control_index,absl::string_view fanin_node_name)571 inline bool AddControllingFanin(NodeViewDiff<GraphViewT>* diff,
572                                 int control_index,
573                                 absl::string_view fanin_node_name) {
574   if (control_index == kMissingIndex) {
575     diff->controlling_inputs_to_add.emplace(fanin_node_name);
576   } else {
577     diff->controlling_inputs_to_remove.erase(control_index);
578   }
579   return true;
580 }
581 
582 // Remove controlling fanin. If the controlling fanin does not exist in the
583 // original node and diff, nothing will happen. If the controlling fanin exists
584 // in the diff, it will be removed. Otherwise the controlling fanin will be
585 // marked for removal from the original node.
586 template <typename GraphViewT>
RemoveControllingFanin(NodeViewDiff<GraphViewT> * diff,int control_index,absl::string_view fanin_node_name)587 inline bool RemoveControllingFanin(NodeViewDiff<GraphViewT>* diff,
588                                    int control_index,
589                                    absl::string_view fanin_node_name) {
590   if (control_index == kMissingIndex) {
591     diff->controlling_inputs_to_add.erase(fanin_node_name);
592   } else {
593     diff->controlling_inputs_to_remove.emplace(control_index);
594   }
595   return true;
596 }
597 
598 // Adds or updates an attribute by name. If an attribute exist in the original
599 // node or diff (including those marked for removal), this will overwrite it.
600 template <typename GraphViewT>
AddOrUpdateAttribute(NodeViewDiff<GraphViewT> * diff,absl::string_view attr_name,const AttrValue & attr_value)601 inline bool AddOrUpdateAttribute(NodeViewDiff<GraphViewT>* diff,
602                                  absl::string_view attr_name,
603                                  const AttrValue& attr_value) {
604   diff->attrs_to_add.empty() ? 0 : diff->attrs_to_remove.erase(attr_name);
605   gtl::InsertOrUpdate(&diff->attrs_to_add, string(attr_name), attr_value);
606   return true;
607 }
608 
609 // Removes an attribute by name. If an attribute exist in the original node or
610 // diff, this will remove it.
611 template <typename GraphViewT>
RemoveAttribute(NodeViewDiff<GraphViewT> * diff,absl::string_view attr_name)612 inline bool RemoveAttribute(NodeViewDiff<GraphViewT>* diff,
613                             absl::string_view attr_name) {
614   const size_t num_erased =
615       diff->attrs_to_add.empty() ? 0 : diff->attrs_to_add.erase(attr_name);
616   auto* node_view = diff->graph_view->GetNode(diff->node_index);
617   if (node_view->HasAttr(attr_name)) {
618     diff->attrs_to_remove.emplace(attr_name);
619     return true;
620   }
621   return num_erased > 0;
622 }
623 
// Drops the trailing run of elements equal to `value` from `*v`. Earlier
// occurrences of `value` are left untouched; the vector is only resized when
// something is actually trimmed.
template <typename T>
inline void ResizeByTrimmingEndForValue(std::vector<T>* v, const T& value) {
  auto new_size = v->size();
  while (new_size > 0 && (*v)[new_size - 1] == value) {
    --new_size;
  }
  if (new_size < v->size()) {
    v->resize(new_size);
  }
}
640 
641 // Checks if any changes are set in the diff.
642 template <typename GraphViewT>
IsEmpty(NodeViewDiff<GraphViewT> * diff)643 inline bool IsEmpty(NodeViewDiff<GraphViewT>* diff) {
644   ResizeByTrimmingEndForValue(&diff->regular_inputs_to_remove, false);
645   ResizeByTrimmingEndForValue(&diff->regular_inputs_to_add, EmptyTensorId());
646   return !diff->update_name && !diff->update_op && !diff->update_device &&
647          diff->regular_inputs_to_add.empty() &&
648          diff->regular_inputs_to_update.empty() &&
649          diff->regular_inputs_to_remove.empty() &&
650          diff->controlling_inputs_to_add.empty() &&
651          diff->controlling_inputs_to_remove.empty() &&
652          diff->attrs_to_add.empty() && diff->attrs_to_remove.empty();
653 }
654 
655 // Resets and clears existing diff.
656 template <typename GraphViewT>
Reset(NodeViewDiff<GraphViewT> * diff)657 inline void Reset(NodeViewDiff<GraphViewT>* diff) {
658   diff->name.clear();
659   diff->update_name = false;
660   diff->op.clear();
661   diff->update_op = false;
662   diff->device.clear();
663   diff->update_device = false;
664   std::vector<SafeTensorId>().swap(diff->regular_inputs_to_add);
665   diff->num_regular_inputs_to_add = false;
666   std::map<int, SafeTensorId>().swap(diff->regular_inputs_to_update);
667   std::vector<bool>().swap(diff->regular_inputs_to_remove);
668   diff->num_regular_inputs_to_remove = 0;
669   absl::flat_hash_set<string>().swap(diff->controlling_inputs_to_add);
670   std::set<int>().swap(diff->controlling_inputs_to_remove);
671   absl::flat_hash_map<string, AttrValue>().swap(diff->attrs_to_add);
672   absl::flat_hash_set<string>().swap(diff->attrs_to_remove);
673 }
674 
675 // Checks if changes to node will result in a valid node.
676 template <typename GraphViewT>
IsWellFormed(NodeViewDiff<GraphViewT> * diff,const absl::flat_hash_map<absl::string_view,int> & updated_node_names)677 inline bool IsWellFormed(
678     NodeViewDiff<GraphViewT>* diff,
679     const absl::flat_hash_map<absl::string_view, int>& updated_node_names) {
680   ResizeByTrimmingEndForValue(&diff->regular_inputs_to_remove, false);
681   ResizeByTrimmingEndForValue(&diff->regular_inputs_to_add, EmptyTensorId());
682   int diff_regular_inputs_to_add_size = diff->regular_inputs_to_add.size();
683   if (diff_regular_inputs_to_add_size != diff->num_regular_inputs_to_add) {
684     // Missing regular fanins in between appended fanins.
685     return false;
686   } else if (diff->num_regular_inputs_to_add > 0 &&
687              !diff->regular_inputs_to_remove.empty()) {
688     // Appending new fanins while removing existing fanins, resulting in missing
689     // regular fanins in between.
690     return false;
691   } else if (static_cast<int>(diff->regular_inputs_to_remove.size()) !=
692              diff->num_regular_inputs_to_remove) {
693     // Regular fanins exist in between removed fanins.
694     return false;
695   }
696   auto* node_view = diff->graph_view->GetNode(diff->node_index);
697   const string& node_name =
698       diff->update_name ? diff->name : node_view->GetName();
699   auto invalid_node_name = [&](absl::string_view fanin_node_name) -> bool {
700     return fanin_node_name == node_name ||
701            !CheckNodeNameExists(fanin_node_name, updated_node_names,
702                                 diff->graph_view);
703   };
704 
705   // Check if nodes of all updated and new fanins exist (from name) and if such
706   // fanins do not introduce self loops. Note, this will not check for if
707   // unmodified fanins exist.
708   if (diff->update_name) {
709     // If name of node was changed in node, check all fanins. Updated fanins are
710     // checked for existence and self loops. Unmodified fanins are checked for
711     // self loops.
712     // `regular_inputs_to_update`, `controlling_inputs_to_remove` are sorted,
713     // so iterators from these maps/sets can be incremented alongside iteration
714     // and be used for comparisons.
715     const int last_index =
716         node_view->NumRegularFanins() - diff->num_regular_inputs_to_remove - 1;
717     auto regular_to_update_it = diff->regular_inputs_to_update.begin();
718     for (int i = 0; i <= last_index; ++i) {
719       if (regular_to_update_it != diff->regular_inputs_to_update.end() &&
720           regular_to_update_it->first < i) {
721         ++regular_to_update_it;
722       }
723       if (regular_to_update_it != diff->regular_inputs_to_update.end() &&
724           regular_to_update_it->first == i) {
725         if (invalid_node_name(regular_to_update_it->second.node())) {
726           return false;
727         }
728       } else {
729         const string& regular_name =
730             node_view->GetRegularFanin(i).node_view()->GetName();
731         if (regular_name == node_name) {
732           return false;
733         }
734       }
735     }
736 
737     auto& controls = node_view->GetControllingFanins();
738     const int num_controls = controls.size();
739     auto control_to_remove_it = diff->controlling_inputs_to_remove.begin();
740     for (int i = 0; i < num_controls; ++i) {
741       if (control_to_remove_it != diff->controlling_inputs_to_remove.end() &&
742           *control_to_remove_it < i) {
743         ++control_to_remove_it;
744       }
745       if (control_to_remove_it != diff->controlling_inputs_to_remove.end() &&
746           *control_to_remove_it == i) {
747         // Control dependency marked for removal, can be ignored.
748         continue;
749       } else if (controls[i].node_view()->GetName() == node_name) {
750         return false;
751       }
752     }
753   } else {
754     // Name of node was not changed, check only updated fanins under the
755     // assumption prior fanins were valid.
756     for (const auto& updated : diff->regular_inputs_to_update) {
757       const string& fanin_name = updated.second.node();
758       if (invalid_node_name(fanin_name)) {
759         return false;
760       }
761     }
762   }
763   // Check appended regular fanins.
764   for (const auto& regular : diff->regular_inputs_to_add) {
765     if (invalid_node_name(regular.node())) {
766       return false;
767     }
768   }
769   // Check new controlling fanins.
770   for (const auto& control : diff->controlling_inputs_to_add) {
771     if (invalid_node_name(control)) {
772       return false;
773     }
774   }
775 
776   return true;
777 }
778 
779 // NewNode is a helper struct holding a new node to be added to a GraphViewT.
780 // This should not be initialized or be used directly.
781 template <typename GraphViewT>
782 struct NewNode {
NewNodeNewNode783   explicit NewNode(GraphViewT* graph_view, NodeDef&& node)
784       : graph_view(graph_view), node(std::move(node)) {}
785 
786   GraphViewT* graph_view;
787   NodeDef node;
788   std::vector<SafeTensorId> regular_fanins;
789   int num_regular_fanins = 0;
790   absl::flat_hash_set<string> controlling_fanins;
791 };
792 
793 // Updates new node name.
794 template <typename GraphViewT>
UpdateName(NewNode<GraphViewT> * new_node,absl::string_view name)795 inline void UpdateName(NewNode<GraphViewT>* new_node, absl::string_view name) {
796   if (name.empty()) {
797     new_node->node.clear_name();
798   } else {
799     new_node->node.set_name(string(name));
800   }
801 }
802 
803 // Updates new node op.
804 template <typename GraphViewT>
UpdateOp(NewNode<GraphViewT> * new_node,absl::string_view op)805 inline void UpdateOp(NewNode<GraphViewT>* new_node, absl::string_view op) {
806   if (op.empty()) {
807     new_node->node.clear_op();
808   } else {
809     new_node->node.set_op(string(op));
810   }
811 }
812 
813 // Updates new node device.
814 template <typename GraphViewT>
UpdateDevice(NewNode<GraphViewT> * new_node,absl::string_view device)815 inline void UpdateDevice(NewNode<GraphViewT>* new_node,
816                          absl::string_view device) {
817   if (device.empty()) {
818     new_node->node.clear_device();
819   } else {
820     new_node->node.set_device(string(device));
821   }
822 }
823 
824 // Adds or updates regular fanin at `index` of regular fanins in the new node.
825 // If another fanin already exists at `index`, it will be replaced with `fanin`.
826 template <typename GraphViewT>
AddOrUpdateRegularFanin(NewNode<GraphViewT> * new_node,int index,const TensorId & fanin)827 inline void AddOrUpdateRegularFanin(NewNode<GraphViewT>* new_node, int index,
828                                     const TensorId& fanin) {
829   if (index < 0) {
830     // Not a valid index for regular fanins.
831     return;
832   } else if (AddOrUpdateAtIndex(&new_node->regular_fanins, index, fanin,
833                                 EmptyTensorId())) {
834     ++new_node->num_regular_fanins;
835   }
836 }
837 
838 // Remove regular fanin at `index` of regular fanins in the new node. This can
839 // remove existing fanins and updated/added fanins via AddOrUpdateRegularFanins.
840 template <typename GraphViewT>
RemoveRegularFanin(NewNode<GraphViewT> * new_node,int index)841 inline void RemoveRegularFanin(NewNode<GraphViewT>* new_node, int index) {
842   int new_node_regular_fanins_size = new_node->regular_fanins.size();
843   if (index < 0 || index >= new_node_regular_fanins_size ||
844       IsEmptyTensorId(new_node->regular_fanins[index])) {
845     return;
846   }
847   new_node->regular_fanins[index] = EmptyTensorId();
848   --new_node->num_regular_fanins;
849 }
850 
851 // Adds controlling fanin to new node.
852 template <typename GraphViewT>
AddControllingFanin(NewNode<GraphViewT> * new_node,absl::string_view fanin_node_name)853 inline void AddControllingFanin(NewNode<GraphViewT>* new_node,
854                                 absl::string_view fanin_node_name) {
855   new_node->controlling_fanins.emplace(fanin_node_name);
856 }
857 
858 // Removes controlling fanin to new node.
859 template <typename GraphViewT>
RemoveControllingFanin(NewNode<GraphViewT> * new_node,absl::string_view fanin_node_name)860 inline void RemoveControllingFanin(NewNode<GraphViewT>* new_node,
861                                    absl::string_view fanin_node_name) {
862   new_node->controlling_fanins.erase(fanin_node_name);
863 }
864 
865 // Adds or updates an attribute by name to a new node.
866 template <typename GraphViewT>
AddOrUpdateAttribute(NewNode<GraphViewT> * new_node,absl::string_view attr_name,const AttrValue & attr_value)867 inline void AddOrUpdateAttribute(NewNode<GraphViewT>* new_node,
868                                  absl::string_view attr_name,
869                                  const AttrValue& attr_value) {
870   gtl::InsertOrUpdate(new_node->node.mutable_attr(), string(attr_name),
871                       attr_value);
872 }
873 
874 // Removes an attribute by name to a new node.
875 template <typename GraphViewT>
RemoveAttribute(NewNode<GraphViewT> * new_node,absl::string_view attr_name)876 inline void RemoveAttribute(NewNode<GraphViewT>* new_node,
877                             absl::string_view attr_name) {
878   new_node->node.mutable_attr()->erase(string(attr_name));
879 }
880 
881 // Checks if current state of new node is a valid node.
882 template <typename GraphViewT>
IsWellFormed(NewNode<GraphViewT> * new_node,const absl::flat_hash_map<absl::string_view,int> & updated_node_names)883 inline bool IsWellFormed(
884     NewNode<GraphViewT>* new_node,
885     const absl::flat_hash_map<absl::string_view, int>& updated_node_names) {
886   ResizeByTrimmingEndForValue(&new_node->regular_fanins, EmptyTensorId());
887   int new_node_regular_fanins_size = new_node->regular_fanins.size();
888   if (new_node_regular_fanins_size != new_node->num_regular_fanins) {
889     return false;
890   }
891 
892   const string& node_name = new_node->node.name();
893   auto invalid_node_name = [new_node, updated_node_names,
894                             node_name](absl::string_view fanin_node_name) {
895     return fanin_node_name == node_name ||
896            !CheckNodeNameExists(fanin_node_name, updated_node_names,
897                                 new_node->graph_view);
898   };
899   // Check if nodes of all fanins exist (from name) and if fanins do not
900   // introduce self loops.
901   for (const auto& regular : new_node->regular_fanins) {
902     if (invalid_node_name(regular.node())) {
903       return false;
904     }
905   }
906   for (const auto& control : new_node->controlling_fanins) {
907     if (invalid_node_name(control)) {
908       return false;
909     }
910   }
911 
912   return true;
913 }
914 
915 }  // namespace internal
916 }  // namespace utils
917 }  // namespace grappler
918 }  // namespace tensorflow
919 
920 #endif  // TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_INTERNAL_H_
921