1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/utils/frame.h"
17 
18 #include <deque>
19 
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/grappler/op_types.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 
25 namespace tensorflow {
26 namespace grappler {
27 
28 namespace {}  // namespace
29 
30 template <typename GraphViewT>
InferFromGraphViewT(const GraphViewT & graph_view)31 inline Status FrameView::InferFromGraphViewT(const GraphViewT& graph_view) {
32   if (is_inferred_) {
33     return errors::Internal("FrameView was already inferred from the graph");
34   }
35   is_inferred_ = true;
36 
37   std::deque<int> ready_node_indices;
38 
39   // All nodes without inputs are automatically added to the ready queue.
40   for (const auto& node : graph_view.GetNodes()) {
41     if (node.NumRegularFanins() + node.NumControllingFanins() == 0) {
42       ready_node_indices.push_back(node.node_index());
43       node_to_frames_[node.node()] = node_has_no_frames_;
44     }
45   }
46 
47   const auto* graph = graph_view.graph();
48 
49   // We assign unique int id to each frame, and use this map to track what
50   // frames we've already seen in the graph.
51   absl::flat_hash_map<string, int> frame_name_to_id;
52 
53   auto process_fanout = [this, graph](
54                             absl::flat_hash_map<string, int>* frame_name_to_id,
55                             std::deque<int>* ready_node_indices,
56                             const NodeDef* ready_node, int fanout_node_index) {
57     const NodeDef* fanout_node = &graph->node(fanout_node_index);
58     if (!node_to_frames_.contains(fanout_node)) {
59       // If we have never seen this node before, we add all frames from the
60       // incoming node (and pop/push frames if coming from Exit/Enter nodes).
61       std::vector<int> frame_ids = node_to_frames_[ready_node];
62 
63       if (IsExit(*ready_node)) {
64         frame_ids.pop_back();
65       }
66 
67       if (IsEnter(*fanout_node)) {
68         const AttrValue* frame_name_attr =
69             AttrSlice(*fanout_node).Find("frame_name");
70 
71         if (!frame_name_attr) {
72           return errors::InvalidArgument(
73               "Missing frame name for the Enter node: ",
74               SummarizeNodeDef(*fanout_node));
75         }
76 
77         const string& frame_name = frame_name_attr->s();
78         int frame_id;
79 
80         if (frame_name_to_id->contains(frame_name)) {
81           frame_id = (*frame_name_to_id)[frame_name];
82         } else {
83           frame_id = static_cast<int>(frame_name_to_id->size());
84           (*frame_name_to_id)[frame_name] = frame_id;
85         }
86 
87         frame_ids.push_back(frame_id);
88       }
89 
90       ready_node_indices->push_back(fanout_node_index);
91       node_to_frames_[fanout_node] = std::move(frame_ids);
92 
93     } else {
94       // If we've already seen this node before, we need to make sure that graph
95       // is correct and same nodes doesn't have incoming edges with conflicting
96       // frames (all inputs must be produces in the same frame).
97 
98       std::vector<int> frame_ids_fanout = node_to_frames_[fanout_node];
99       std::vector<int> frame_ids_node = node_to_frames_[ready_node];
100 
101       if (IsEnter(*fanout_node)) {
102         frame_ids_fanout.pop_back();
103       }
104       if (IsExit(*ready_node)) {
105         frame_ids_node.pop_back();
106       }
107 
108       if (frame_ids_node != frame_ids_fanout) {
109         return errors::InvalidArgument(
110             "Invalid graph: Frame ids for node ", ready_node->name(),
111             " does not match frame ids for it's fanout ", fanout_node->name());
112       }
113     }
114     return Status::OK();
115   };
116 
117   while (!ready_node_indices.empty()) {
118     const int ready_node_index = ready_node_indices.front();
119     ready_node_indices.pop_front();
120     const auto* ready_node_view = graph_view.GetNode(ready_node_index);
121     const NodeDef* ready_node_def = ready_node_view->node();
122 
123     for (const auto& regular_fanouts_port_i :
124          ready_node_view->GetRegularFanouts()) {
125       for (const auto& regular_fanout : regular_fanouts_port_i) {
126         TF_RETURN_IF_ERROR(process_fanout(&frame_name_to_id,
127                                           &ready_node_indices, ready_node_def,
128                                           regular_fanout.node_index()));
129       }
130     }
131 
132     for (const auto& controlled_fanout :
133          ready_node_view->GetControlledFanouts()) {
134       TF_RETURN_IF_ERROR(process_fanout(&frame_name_to_id, &ready_node_indices,
135                                         ready_node_def,
136                                         controlled_fanout.node_index()));
137     }
138   }
139 
140   num_frames_ = static_cast<int>(frame_name_to_id.size());
141   return Status::OK();
142 }
143 
InferFromGraphView(const utils::GraphView & graph_view)144 Status FrameView::InferFromGraphView(const utils::GraphView& graph_view) {
145   return InferFromGraphViewT(graph_view);
146 }
147 
InferFromGraphView(const utils::MutableGraphView & graph_view)148 Status FrameView::InferFromGraphView(
149     const utils::MutableGraphView& graph_view) {
150   return InferFromGraphViewT(graph_view);
151 }
152 
InferFromGraph(const GraphDef & graph)153 Status FrameView::InferFromGraph(const GraphDef& graph) {
154   Status status;
155   utils::GraphView graph_view(&graph, &status);
156   TF_RETURN_IF_ERROR(status);
157   return InferFromGraphViewT(graph_view);
158 }
159 
Frames(const NodeDef & node) const160 const std::vector<int>& FrameView::Frames(const NodeDef& node) const {
161   DCHECK(is_inferred_) << "FrameView is not initialized";
162   auto frames = node_to_frames_.find(&node);
163   if (frames == node_to_frames_.end()) {
164     LOG(WARNING) << "Node '" << node.name()
165                  << "' doesn't belong to the graph used for initialization";
166     return node_has_no_frames_;
167   } else {
168     return frames->second;
169   }
170 }
171 
IsInFrame(const NodeDef & node) const172 bool FrameView::IsInFrame(const NodeDef& node) const {
173   return !Frames(node).empty();
174 }
175 
176 }  // namespace grappler
177 }  // namespace tensorflow
178