/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/lower_case_op.h"

#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

namespace {

using NodeOut = NodeBuilder::NodeOut;

constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
    LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;

// Convenience builder to make it easy to construct a Case with a single
// function call in each branch. It converts the Case node into switches (for
// inputs) and merges (for outputs) around one function call per branch.
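//
// Roughly, the lowered graph for a Case with N branches, a single data input
// `x` and a single data output looks like this (a sketch only; ^ marks
// control edges, and names are abbreviated):
//
//   branch_index -> _SwitchN(branch_index, branch_index) -> pivot_b (Identity)
//   x            -> _SwitchN(x, branch_index)            -> call_b (branch b)
//   {call_0:0, ..., call_{N-1}:0}                        -> Merge -> output
//   {pivot_0, ..., pivot_{N-1}}, {^call_b}               -> Merge (branch_executed)
//   branch_executed                                      -> ^IdentityN / ^NoOp
//
// Each pivot_b also has a control edge to call_b, and the final IdentityN (or
// NoOp) node takes over the name of the original Case node.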
class CaseBuilder {
 public:
  // Create a CaseBuilder to create the lowered form of `case_op` with branch
  // functions identified by `branch_fn_names` in the `graph`.
  CaseBuilder(Node* case_op, const std::vector<string>& branch_fn_names,
              bool keep_node_fetchable, Graph* graph);

  // Constructs the basic conditional control flow: the branch-index _SwitchN
  // node and one pivot Identity node per branch.
  Status CreatePivotNodes();

  // Adds the inputs of the Case node to the branch call nodes, routing each
  // input through a dedicated _SwitchN node.
  Status AddInputs();

  // Adds the outputs of the branch call nodes to the Merge nodes of the
  // lowered Case. Note: no inputs can be added once outputs are added, as the
  // branch call nodes are finalized while adding outputs.
  Status AddOutputs();

  // Builds an identity node with the same outputs as Case.
  Status BuildLoweredCaseOutput();

 private:
  // Returns a unique name containing the name of the Case op being rewritten
  // (name_), the given infix, and a suffix that makes it unique within the
  // graph.
  string NewName(const string& infix);

  // Adds src:src_output as an input to each branch call node, routed through
  // a _SwitchN node.
  Status AddInput(Node* src, int src_output);

  // The merged outputs of the branch call nodes.
  std::vector<NodeOut> outputs_;

  // The node that dominates all execution of the branch body nodes.
  Node* control_predecessor_;
  // The original Case op.
  Node* case_op_;
  // The node with the same name as the original Case op:
  //   (a) IdentityN node with same outputs if 'keep_node_fetchable_ == true'
  //       and if the original Case op had non-zero data outputs.
  //   (b) NoOp node with control edge from 'branch_executed_node_' otherwise.
  Node* lowered_case_output_;
  // The branch selector of the Case.
  OutputTensor branch_index_;
  int num_branches_;
  // pivots_[i] is an Identity node reading the i'th output of the branch_index
  // _SwitchN node; it is the pivot node that dominates all nodes in the i'th
  // branch.
  std::vector<Node*> pivots_;
  std::vector<Node*> call_nodes_;
  // Merge node that has inputs from each of pivots_ and control edges from
  // [^call_node for call_node in call_nodes_]. This node guarantees that the
  // branch functions are executed for their side effects even when they have
  // no data outputs.
  Node* branch_executed_node_;
  Graph* graph_;
  string name_;
  bool keep_node_fetchable_;

  NodeDebugInfo debug_info_;
  std::vector<NodeBuilder> branch_call_builders_;
};

CaseBuilder::CaseBuilder(Node* case_op,
                         const std::vector<string>& branch_fn_names,
                         bool keep_node_fetchable, Graph* graph)
    : case_op_(case_op),
      num_branches_(branch_fn_names.size()),
      graph_(graph),
      name_(case_op->name()),
      keep_node_fetchable_(keep_node_fetchable),
      debug_info_(*case_op_) {
  branch_call_builders_.reserve(num_branches_);
  for (int b = 0; b < num_branches_; b++) {
    branch_call_builders_.emplace_back(NewName(strings::StrCat("branch", b)),
                                       branch_fn_names[b], graph->op_registry(),
                                       &debug_info_);
    branch_call_builders_[b].Device(case_op_->requested_device());
    branch_call_builders_[b].Attr(kLowerAsMultiDeviceFunctionAttr, true);
  }
  TF_CHECK_OK(case_op_->input_tensor(0, &branch_index_));
}

Status CaseBuilder::CreatePivotNodes() {
  // Construct the basic case body (consisting of feeding in the branch index
  // to create pivot nodes).
  Node* branch_index;
  // The branch index is both the data input and the selector input of the
  // _SwitchN node, so output `b` forwards the index only when branch `b` is
  // selected.
  TF_RETURN_IF_ERROR(NodeBuilder(NewName("branch_index"), "_SwitchN",
                                 graph_->op_registry(), &debug_info_)
                         .Input(NodeOut(branch_index_))
                         .Input(NodeOut(branch_index_))
                         .Attr("num_outs", num_branches_)
                         .Device(case_op_->requested_device())
                         .Finalize(graph_, &branch_index));
  control_predecessor_ = branch_index;
  pivots_.resize(num_branches_, nullptr);
  for (int b = 0; b < num_branches_; b++) {
    TF_RETURN_IF_ERROR(NodeBuilder(NewName(strings::StrCat("pivot_", b)),
                                   "Identity", graph_->op_registry(),
                                   &debug_info_)
                           .Input(branch_index, b)
                           .Device(case_op_->requested_device())
                           .Finalize(graph_, &pivots_[b]));
  }
  return Status::OK();
}

string CaseBuilder::NewName(const string& infix) {
  return graph_->NewName(strings::StrCat(name_, "/", infix));
}

Status CaseBuilder::AddInput(Node* src, int src_output) {
  Node* input;
  NodeDebugInfo debug_info(*src);
  // Colocate the Switch node with the `src` node.
  //
  // This is to avoid unnecessary Host<->Device copies between src and the
  // _SwitchN node. This aligns with the implementation of legacy tf.cond in
  // control_flow_ops.py. The legacy impl colocates the Switch with the input
  // tensor, which resets the device stack, forces the Switch to have the same
  // device as the input node (if set), and sets the colocation _class attr.
  // It also ignores the existing colocation constraints on the input node
  // using colocate_with(ignore_existing=True).
  TF_RETURN_IF_ERROR(NodeBuilder(NewName(src->name()), "_SwitchN",
                                 graph_->op_registry(), &debug_info)
                         .Input(src, src_output)
                         .Input(branch_index_)
                         .Device(src->requested_device())
                         .Attr("_class", {src->name()})
                         .Attr("num_outs", num_branches_)
                         .Finalize(graph_, &input));
  for (int b = 0; b < num_branches_; b++) {
    branch_call_builders_[b].Input(input, b);
  }
  return Status::OK();
}

Status CaseBuilder::AddInputs() {
  // Add input data edges.
  std::vector<const Edge*> edges;
  TF_RETURN_IF_ERROR(case_op_->input_edges(&edges));
  // Start at index 1 as the first input is the branch index.
  for (int i = 1; i < edges.size(); ++i) {
    const Edge* e = edges[i];
    TF_RETURN_IF_ERROR(AddInput(e->src(), e->src_output()));
  }
  // Add input control edges.
  for (const Edge* e : case_op_->in_edges()) {
    if (e->IsControlEdge()) {
      graph_->AddControlEdge(e->src(), control_predecessor_);
    }
  }
  return Status::OK();
}

Status CaseBuilder::AddOutputs() {
  // Construct the call nodes for each branch.
  call_nodes_.resize(num_branches_, nullptr);
  for (int b = 0; b < num_branches_; b++) {
    TF_RETURN_IF_ERROR(
        branch_call_builders_[b].Finalize(graph_, &call_nodes_[b]));
    graph_->AddControlEdge(pivots_[b], call_nodes_[b]);
  }

  // Merge the outputs from the N branches (all branches have matching
  // outputs).
  const int num_outputs = call_nodes_[0]->num_outputs();
  std::vector<Node*> merges(num_outputs);
  outputs_.resize(merges.size());
  for (int i = 0; i < num_outputs; ++i) {
    std::vector<NodeOut> merge_input;
    merge_input.reserve(num_branches_);
    for (int j = 0; j < num_branches_; j++) {
      merge_input.emplace_back(call_nodes_[j], i);
    }
    TF_RETURN_IF_ERROR(NodeBuilder(NewName("merge"), "Merge",
                                   graph_->op_registry(), &debug_info_)
                           .Input(merge_input)
                           .Device(case_op_->requested_device())
                           .Finalize(graph_, &merges[i]));
    outputs_[i] = NodeOut(merges[i], 0);
  }

  // Add a Merge node that will be used as a control dependency source for the
  // lowered output node. This Merge node guarantees that the lowered branch
  // function calls are executed even if they do not have data outputs.
  //
  // Furthermore, it guarantees that all function side effects are executed if
  // the function is inlined into the graph. Having data outputs is not
  // enough, because they might become unused after inlining.
  //
  // We will use this node to rewrite outgoing control edges from the lowered
  // 'Case' node. All data edges will read tensors directly from Merge nodes.
  std::vector<NodeOut> pivots(num_branches_);
  for (int j = 0; j < num_branches_; j++) {
    pivots[j] = NodeOut(pivots_[j]);
  }
  TF_RETURN_IF_ERROR(NodeBuilder(NewName("branch_executed"), "Merge",
                                 graph_->op_registry(), &debug_info_)
                         .Input(pivots)
                         .ControlInputs(call_nodes_)
                         .Device(case_op_->requested_device())
                         .Finalize(graph_, &branch_executed_node_));

  TF_RETURN_IF_ERROR(BuildLoweredCaseOutput());

  // Add outputs.
  for (const Edge* e : case_op_->out_edges()) {
    if (e->IsControlEdge()) {
      graph_->AddControlEdge(branch_executed_node_, e->dst());
    } else {
      // Feed the outputs directly from the merge nodes so that downstream ops
      // can start before all the outputs have been computed.
      graph_->AddEdge(merges[e->src_output()], 0, e->dst(), e->dst_input());
    }
  }
  return Status::OK();
}

Status CaseBuilder::BuildLoweredCaseOutput() {
  // If outputs are empty, it means that we might have only output control
  // edges (already connected to the `branch_executed_node`). Furthermore,
  // it's illegal to have an IdentityN with empty inputs.
  //
  // We still must keep the lowered Case node as a valid source of control
  // edges, because it might be a part of a function control output set.
  NodeBuilder builder = keep_node_fetchable_ && !outputs_.empty()
                            ? NodeBuilder(name_, "IdentityN").Input(outputs_)
                            : NodeBuilder(name_, "NoOp");
  return builder.Device(case_op_->requested_device())
      .ControlInput(branch_executed_node_)
      .Finalize(graph_, &lowered_case_output_);
}

}  // namespace

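// Usage sketch (illustrative only, not the actual call site; in practice this
// rewrite is driven by the functional-ops lowering pass): a caller that wants
// to lower every Case node in a graph could do roughly the following.
//
//   std::vector<Node*> case_nodes;
//   for (Node* n : graph->op_nodes()) {
//     if (n->type_string() == "Case") case_nodes.push_back(n);
//   }
//   for (Node* n : case_nodes) {
//     TF_RETURN_IF_ERROR(
//         RewriteCaseNode(n, graph, /*keep_node_fetchable=*/true));
//   }
//
// Note that RewriteCaseNode removes `n` from the graph on success, which is
// why the Case nodes are collected before any of them is rewritten.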
Status RewriteCaseNode(Node* n, Graph* g, bool keep_node_fetchable) {
  VLOG(2) << "Lower Case node (keep_node_fetchable=" << keep_node_fetchable
          << "): " << SummarizeNode(*n);
  const AttrValue* branches_attr = n->attrs().Find("branches");
  if (branches_attr == nullptr) {
    return errors::InvalidArgument("branch functions missing");
  }

  int num_branches = branches_attr->list().func_size();
  std::vector<string> branch_fn_names;
  branch_fn_names.reserve(num_branches);
  for (int b = 0; b < num_branches; b++) {
    branch_fn_names.emplace_back(branches_attr->list().func(b).name());
  }
  CaseBuilder cb(n, branch_fn_names, keep_node_fetchable, g);
  TF_RETURN_IF_ERROR(cb.CreatePivotNodes());
  TF_RETURN_IF_ERROR(cb.AddInputs());
  TF_RETURN_IF_ERROR(cb.AddOutputs());
  g->RemoveNode(n);

  return Status::OK();
}

}  // namespace tensorflow