1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/grappler/utils/frame.h"

#include <algorithm>
#include <deque>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/lib/core/errors.h"
23 
24 namespace tensorflow {
25 namespace grappler {
26 
27 namespace {}  // namespace
28 
InferFromGraphView(const GraphView & graph_view)29 Status FrameView::InferFromGraphView(const GraphView& graph_view) {
30   if (is_inferred_) {
31     return errors::Internal("FrameView was already inferred from the graph");
32   }
33   is_inferred_ = true;
34 
35   std::deque<const NodeDef*> ready_nodes;
36 
37   // All nodes without inputs are automatically added to the ready queue.
38   for (const NodeDef& node : graph_view.graph()->node()) {
39     if (node.input_size() == 0) {
40       ready_nodes.push_back(&node);
41       node_to_frames_[&node] = node_has_no_frames_;
42     }
43   }
44 
45   // We assign unique int id to each frame, and use this map to track what
46   // frames we've already seen in the graph.
47   absl::flat_hash_map<string, int> frame_name_to_id;
48 
49   while (!ready_nodes.empty()) {
50     const NodeDef* ready_node = ready_nodes.front();
51 
52     absl::flat_hash_set<GraphView::InputPort> fanouts =
53         graph_view.GetFanouts(*ready_node, /*include_controlled_nodes=*/true);
54 
55     for (const GraphView::InputPort& fanout : fanouts) {
56       if (node_to_frames_.count(fanout.node) < 1) {
57         // If we have never seen this node before, we add all frames from the
58         // incoming node (and pop/push frames if coming from Exit/Enter nodes).
59         std::vector<int> frame_ids = node_to_frames_[ready_node];
60 
61         if (IsExit(*ready_node)) {
62           frame_ids.pop_back();
63         }
64 
65         if (IsEnter(*fanout.node)) {
66           const AttrValue* frame_name_attr =
67               AttrSlice(*fanout.node).Find("frame_name");
68 
69           if (!frame_name_attr) {
70             return errors::InvalidArgument(
71                 "Missing frame name for the Enter node: ",
72                 SummarizeNodeDef(*fanout.node));
73           }
74 
75           absl::string_view frame_name = frame_name_attr->s();
76           int frame_id;
77 
78           if (frame_name_to_id.count(frame_name)) {
79             frame_id = frame_name_to_id[frame_name];
80           } else {
81             frame_id = static_cast<int>(frame_name_to_id.size());
82             frame_name_to_id[frame_name] = frame_id;
83           }
84 
85           frame_ids.push_back(frame_id);
86         }
87 
88         ready_nodes.push_back(fanout.node);
89         node_to_frames_[fanout.node] = std::move(frame_ids);
90 
91       } else {
92         // If we've already seen this node before, we need to make sure that
93         // graph is correct and same nodes doesn't have incoming edges with
94         // conflicting frames (all inputs must be produces in the same frame).
95 
96         std::vector<int> frame_ids_fanout = node_to_frames_[fanout.node];
97         std::vector<int> frame_ids_node = node_to_frames_[ready_node];
98 
99         if (IsEnter(*fanout.node)) {
100           frame_ids_fanout.pop_back();
101         }
102         if (IsExit(*ready_node)) {
103           frame_ids_node.pop_back();
104         }
105 
106         if (frame_ids_node != frame_ids_fanout) {
107           return errors::InvalidArgument(
108               "Invalid graph: Frame ids for node ", ready_node->name(),
109               " does not match frame ids for it's fanout ",
110               fanout.node->name());
111         }
112       }
113     }
114 
115     ready_nodes.pop_front();
116   }
117 
118   num_frames_ = static_cast<int>(frame_name_to_id.size());
119   return Status::OK();
120 }
121 
InferFromGraph(const GraphDef & graph)122 Status FrameView::InferFromGraph(const GraphDef& graph) {
123   return InferFromGraphView(GraphView(&graph));
124 }
125 
Frames(const NodeDef & node) const126 const std::vector<int>& FrameView::Frames(const NodeDef& node) const {
127   DCHECK(is_inferred_) << "FrameView is not initialized";
128   auto frames = node_to_frames_.find(&node);
129   if (frames == node_to_frames_.end()) {
130     LOG(WARNING) << "Node doesn't belong to the graph used for initialization";
131     return node_has_no_frames_;
132   } else {
133     return frames->second;
134   }
135 }
136 
IsInFrame(const NodeDef & node) const137 bool FrameView::IsInFrame(const NodeDef& node) const {
138   return !Frames(node).empty();
139 }
140 
141 }  // namespace grappler
142 }  // namespace tensorflow
143