1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/utils/frame.h"
17
18 #include <deque>
19
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/grappler/op_types.h"
23 #include "tensorflow/core/lib/core/errors.h"
24
25 namespace tensorflow {
26 namespace grappler {
27
28 namespace {} // namespace
29
30 template <typename GraphViewT>
InferFromGraphViewT(const GraphViewT & graph_view)31 inline Status FrameView::InferFromGraphViewT(const GraphViewT& graph_view) {
32 if (is_inferred_) {
33 return errors::Internal("FrameView was already inferred from the graph");
34 }
35 is_inferred_ = true;
36
37 std::deque<int> ready_node_indices;
38
39 // All nodes without inputs are automatically added to the ready queue.
40 for (const auto& node : graph_view.GetNodes()) {
41 if (node.NumRegularFanins() + node.NumControllingFanins() == 0) {
42 ready_node_indices.push_back(node.node_index());
43 node_to_frames_[node.node()] = node_has_no_frames_;
44 }
45 }
46
47 const auto* graph = graph_view.graph();
48
49 // We assign unique int id to each frame, and use this map to track what
50 // frames we've already seen in the graph.
51 absl::flat_hash_map<string, int> frame_name_to_id;
52
53 auto process_fanout = [this, graph](
54 absl::flat_hash_map<string, int>* frame_name_to_id,
55 std::deque<int>* ready_node_indices,
56 const NodeDef* ready_node, int fanout_node_index) {
57 const NodeDef* fanout_node = &graph->node(fanout_node_index);
58 if (!node_to_frames_.contains(fanout_node)) {
59 // If we have never seen this node before, we add all frames from the
60 // incoming node (and pop/push frames if coming from Exit/Enter nodes).
61 std::vector<int> frame_ids = node_to_frames_[ready_node];
62
63 if (IsExit(*ready_node)) {
64 frame_ids.pop_back();
65 }
66
67 if (IsEnter(*fanout_node)) {
68 const AttrValue* frame_name_attr =
69 AttrSlice(*fanout_node).Find("frame_name");
70
71 if (!frame_name_attr) {
72 return errors::InvalidArgument(
73 "Missing frame name for the Enter node: ",
74 SummarizeNodeDef(*fanout_node));
75 }
76
77 const string& frame_name = frame_name_attr->s();
78 int frame_id;
79
80 if (frame_name_to_id->contains(frame_name)) {
81 frame_id = (*frame_name_to_id)[frame_name];
82 } else {
83 frame_id = static_cast<int>(frame_name_to_id->size());
84 (*frame_name_to_id)[frame_name] = frame_id;
85 }
86
87 frame_ids.push_back(frame_id);
88 }
89
90 ready_node_indices->push_back(fanout_node_index);
91 node_to_frames_[fanout_node] = std::move(frame_ids);
92
93 } else {
94 // If we've already seen this node before, we need to make sure that graph
95 // is correct and same nodes doesn't have incoming edges with conflicting
96 // frames (all inputs must be produces in the same frame).
97
98 std::vector<int> frame_ids_fanout = node_to_frames_[fanout_node];
99 std::vector<int> frame_ids_node = node_to_frames_[ready_node];
100
101 if (IsEnter(*fanout_node)) {
102 frame_ids_fanout.pop_back();
103 }
104 if (IsExit(*ready_node)) {
105 frame_ids_node.pop_back();
106 }
107
108 if (frame_ids_node != frame_ids_fanout) {
109 return errors::InvalidArgument(
110 "Invalid graph: Frame ids for node ", ready_node->name(),
111 " does not match frame ids for it's fanout ", fanout_node->name());
112 }
113 }
114 return Status::OK();
115 };
116
117 while (!ready_node_indices.empty()) {
118 const int ready_node_index = ready_node_indices.front();
119 ready_node_indices.pop_front();
120 const auto* ready_node_view = graph_view.GetNode(ready_node_index);
121 const NodeDef* ready_node_def = ready_node_view->node();
122
123 for (const auto& regular_fanouts_port_i :
124 ready_node_view->GetRegularFanouts()) {
125 for (const auto& regular_fanout : regular_fanouts_port_i) {
126 TF_RETURN_IF_ERROR(process_fanout(&frame_name_to_id,
127 &ready_node_indices, ready_node_def,
128 regular_fanout.node_index()));
129 }
130 }
131
132 for (const auto& controlled_fanout :
133 ready_node_view->GetControlledFanouts()) {
134 TF_RETURN_IF_ERROR(process_fanout(&frame_name_to_id, &ready_node_indices,
135 ready_node_def,
136 controlled_fanout.node_index()));
137 }
138 }
139
140 num_frames_ = static_cast<int>(frame_name_to_id.size());
141 return Status::OK();
142 }
143
InferFromGraphView(const utils::GraphView & graph_view)144 Status FrameView::InferFromGraphView(const utils::GraphView& graph_view) {
145 return InferFromGraphViewT(graph_view);
146 }
147
InferFromGraphView(const utils::MutableGraphView & graph_view)148 Status FrameView::InferFromGraphView(
149 const utils::MutableGraphView& graph_view) {
150 return InferFromGraphViewT(graph_view);
151 }
152
InferFromGraph(const GraphDef & graph)153 Status FrameView::InferFromGraph(const GraphDef& graph) {
154 Status status;
155 utils::GraphView graph_view(&graph, &status);
156 TF_RETURN_IF_ERROR(status);
157 return InferFromGraphViewT(graph_view);
158 }
159
Frames(const NodeDef & node) const160 const std::vector<int>& FrameView::Frames(const NodeDef& node) const {
161 DCHECK(is_inferred_) << "FrameView is not initialized";
162 auto frames = node_to_frames_.find(&node);
163 if (frames == node_to_frames_.end()) {
164 LOG(WARNING) << "Node '" << node.name()
165 << "' doesn't belong to the graph used for initialization";
166 return node_has_no_frames_;
167 } else {
168 return frames->second;
169 }
170 }
171
IsInFrame(const NodeDef & node) const172 bool FrameView::IsInFrame(const NodeDef& node) const {
173 return !Frames(node).empty();
174 }
175
176 } // namespace grappler
177 } // namespace tensorflow
178