1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/fusion_utils.h"
17 
18 #include "tensorflow/core/framework/node_def.pb.h"
19 #include "tensorflow/core/framework/node_def_builder.h"
20 #include "tensorflow/core/framework/op_def.pb.h"
21 #include "tensorflow/core/grappler/grappler_item.h"
22 #include "tensorflow/core/grappler/mutable_graph_view.h"
23 #include "tensorflow/core/grappler/op_types.h"
24 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
25 #include "tensorflow/core/grappler/optimizers/data/function_utils.h"
26 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
27 #include "tensorflow/core/grappler/utils.h"
28 #include "tensorflow/core/lib/gtl/flatmap.h"
29 #include "tensorflow/core/lib/gtl/flatset.h"
30 #include "tensorflow/core/lib/gtl/map_util.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/protobuf.h"
33 
34 namespace tensorflow {
35 namespace grappler {
36 namespace fusion_utils {
37 
38 namespace {
// Returns the node-name part of a tensor reference such as "node:output:0".
// If `name` contains a colon, everything before the first colon is returned;
// otherwise the whole string is returned.  Note that a leading '^' (control
// dependency marker) is kept as part of the result.
string ParseNodeConnection(const string& name) {
  // find() yields npos when there is no colon, and substr(0, npos) is the
  // whole string.
  return name.substr(0, name.find(':'));
}
44 
// Returns the output suffix of a tensor reference, starting at (and including)
// the first colon, e.g. "node:output:0" -> ":output:0".  Returns an empty
// string when the reference has no colon.
string ParseOutputNode(const string& name) {
  const auto colon_pos = name.find(':');
  return colon_pos == string::npos ? string() : name.substr(colon_pos);
}
49 
GetOutputNode(const FunctionDef & function,int output_idx)50 string GetOutputNode(const FunctionDef& function, int output_idx) {
51   const auto& ret_output_name =
52       function.signature().output_arg(output_idx).name();
53   return function.ret().at(ret_output_name);
54 }
55 
GetMutableOutputNode(FunctionDef * function,int output_idx)56 string& GetMutableOutputNode(FunctionDef* function, int output_idx) {
57   const auto& ret_output_name =
58       function->signature().output_arg(output_idx).name();
59   return function->mutable_ret()->at(ret_output_name);
60 }
61 
62 template <typename Iterable>
GetNames(const Iterable & iterable,int allocate_size)63 StringCollection GetNames(const Iterable& iterable, int allocate_size) {
64   StringCollection names;
65   names.reserve(allocate_size);
66   for (auto& arg : iterable) names.push_back(arg.name());
67   return names;
68 }
69 
70 template <typename Iterable>
GetNodeNamesSet(const Iterable & nodes)71 gtl::FlatSet<string> GetNodeNamesSet(const Iterable& nodes) {
72   // NOTE(prazek): Cases where the set is not modified after construction
73   // could use sorted vector with binary_search instead, to make it faster.
74   gtl::FlatSet<string> names;
75   for (const auto& node : nodes) {
76     CHECK(gtl::InsertIfNotPresent(&names, node.name()))
77         << "Functions should have unique node names. Node with name "
78         << node.name() << " already exists";
79   }
80   return names;
81 }
82 
83 template <typename Iterable>
GetUniqueNames(const Iterable & first_iterable,const Iterable & second_iterable)84 gtl::FlatMap<string, string> GetUniqueNames(const Iterable& first_iterable,
85                                             const Iterable& second_iterable) {
86   gtl::FlatMap<string, string> changed_node_names;
87   const auto first_names = GetNodeNamesSet(first_iterable);
88   auto second_names = GetNodeNamesSet(first_iterable);
89   int id = second_iterable.size();
90 
91   for (const auto& node : second_iterable) {
92     string name_before = node.name();
93     string name = name_before;
94     bool changed_name = false;
95 
96     while (first_names.count(name) ||
97            (changed_name && second_names.count(name))) {
98       name = strings::StrCat(name_before, "/_", id);
99       changed_name = true;
100       ++id;
101     }
102     if (changed_name) {
103       changed_node_names[name_before] = name;
104       // We don't want to pick a new name that would collide with another new
105       // name.
106       second_names.insert(std::move(name));
107     }
108   }
109   return changed_node_names;
110 }
111 
112 // We need to rename them and the connections of the inputs that refer to them.
113 // Nodes that will be added to the function can have the same name as the nodes
114 // from parent function.
RenameFunctionNodes(const FunctionDef & first_function,protobuf::RepeatedPtrField<NodeDef> * nodes_to_fuse,protobuf::Map<string,string> * rets_to_fuse)115 void RenameFunctionNodes(const FunctionDef& first_function,
116                          protobuf::RepeatedPtrField<NodeDef>* nodes_to_fuse,
117                          protobuf::Map<string, string>* rets_to_fuse) {
118   const gtl::FlatMap<string, string> changed_node_names =
119       GetUniqueNames(first_function.node_def(), *nodes_to_fuse);
120 
121   auto update_name = [&changed_node_names](string* input) {
122     string input_node = ParseNodeConnection(*input);
123     auto iter = changed_node_names.find(input_node);
124     if (iter != changed_node_names.end()) {
125       *input = iter->second + ParseOutputNode(*input);
126     }
127   };
128 
129   for (NodeDef& function_node : *nodes_to_fuse) {
130     if (const string* new_name =
131             gtl::FindOrNull(changed_node_names, function_node.name())) {
132       function_node.set_name(*new_name);
133     }
134 
135     for (string& input : *function_node.mutable_input()) {
136       update_name(&input);
137     }
138   }
139 
140   for (auto& ret : *rets_to_fuse) update_name(&ret.second);
141 }
142 
GetFunctionInputs(const FunctionDef & function)143 StringCollection GetFunctionInputs(const FunctionDef& function) {
144   return GetNames(function.signature().input_arg(),
145                   function.signature().input_arg_size());
146 }
147 
148 // This function produces signature having names that do not conflict with
149 // `first_signature`.  The input of returns and nodes that will be fused are
150 // updated to use new names.
GetUniqueSignature(const OpDef & first_signature,const OpDef & second_signature,protobuf::Map<string,string> * rets_to_fuse,protobuf::RepeatedPtrField<NodeDef> * nodes_to_fuse)151 OpDef GetUniqueSignature(const OpDef& first_signature,
152                          const OpDef& second_signature,
153                          protobuf::Map<string, string>* rets_to_fuse,
154                          protobuf::RepeatedPtrField<NodeDef>* nodes_to_fuse) {
155   const gtl::FlatMap<string, string> changed_input_names =
156       GetUniqueNames(first_signature.input_arg(), second_signature.input_arg());
157   OpDef signature;
158   signature.set_name(second_signature.name());
159 
160   for (const auto& input_arg : second_signature.input_arg()) {
161     auto& input = *signature.add_input_arg();
162     input = input_arg;
163     if (const string* new_name =
164             gtl::FindOrNull(changed_input_names, input.name())) {
165       input.set_name(*new_name);
166     }
167   }
168   const gtl::FlatMap<string, string> changed_output_names = GetUniqueNames(
169       first_signature.output_arg(), second_signature.output_arg());
170 
171   for (const auto& output_arg : second_signature.output_arg()) {
172     auto& output = *signature.add_output_arg();
173     output = output_arg;
174     if (const string* new_name =
175             gtl::FindOrNull(changed_output_names, output.name())) {
176       output.set_name(*new_name);
177     }
178   }
179 
180   protobuf::Map<string, string> new_rets;
181   for (const auto& ret : *rets_to_fuse) {
182     const auto& key = changed_output_names.count(ret.first)
183                           ? changed_output_names.at(ret.first)
184                           : ret.first;
185     const auto& input = ParseNodeConnection(ret.second);
186     const auto& value =
187         changed_input_names.count(input)
188             ? changed_input_names.at(input) + ParseOutputNode(ret.second)
189             : ret.second;
190     new_rets[key] = value;
191   }
192   *rets_to_fuse = std::move(new_rets);
193 
194   for (NodeDef& function_node : *nodes_to_fuse) {
195     for (auto& node_input : *function_node.mutable_input()) {
196       const auto& input = ParseNodeConnection(node_input);
197       if (const string* new_name =
198               gtl::FindOrNull(changed_input_names, input)) {
199         node_input = *new_name + ParseOutputNode(node_input);
200       }
201     }
202   }
203 
204   return signature;
205 }
206 
207 // This function adds new nodes and changes their input to the output nodes
208 // of parent function.  It assumes that the name of nodes to fuse are not
209 // conflicting.
FuseFunctionNodes(const StringCollection & first_inputs,const StringCollection & second_inputs,const StringCollection & first_outputs,const SetInputFn & set_input,protobuf::RepeatedPtrField<NodeDef> * nodes_to_fuse)210 void FuseFunctionNodes(const StringCollection& first_inputs,
211                        const StringCollection& second_inputs,
212                        const StringCollection& first_outputs,
213                        const SetInputFn& set_input,
214                        protobuf::RepeatedPtrField<NodeDef>* nodes_to_fuse) {
215   for (NodeDef& function_node : *nodes_to_fuse) {
216     for (auto& node_input : *function_node.mutable_input()) {
217       auto parsed_name = ParseNodeConnection(node_input);
218 
219       auto input_it =
220           std::find(second_inputs.begin(), second_inputs.end(), parsed_name);
221       if (input_it == second_inputs.end()) continue;
222 
223       auto arg_num = std::distance(second_inputs.begin(), input_it);
224       node_input =
225           set_input(first_inputs, second_inputs, first_outputs, arg_num);
226     }
227   }
228 }
229 
230 // This function looks for direct edges from input to return and rewrites
231 // them to the corresponding input of the return of `first_function`.
FuseReturns(const StringCollection & first_inputs,const StringCollection & second_inputs,const StringCollection & first_outputs,const SetInputFn & set_input,protobuf::Map<string,string> * fused_ret)232 void FuseReturns(const StringCollection& first_inputs,
233                  const StringCollection& second_inputs,
234                  const StringCollection& first_outputs,
235                  const SetInputFn& set_input,
236                  protobuf::Map<string, string>* fused_ret) {
237   for (auto& ret : *fused_ret) {
238     auto return_input = ParseNodeConnection(ret.second);
239     auto input_it =
240         std::find(second_inputs.begin(), second_inputs.end(), return_input);
241     if (input_it == second_inputs.end()) continue;
242 
243     auto input_idx = std::distance(second_inputs.begin(), input_it);
244     ret.second =
245         set_input(first_inputs, second_inputs, first_outputs, input_idx);
246   }
247 }
248 
249 // Returns collection of node names that are used as a return from function.
GetFunctionOutputs(const FunctionDef & function)250 StringCollection GetFunctionOutputs(const FunctionDef& function) {
251   const auto number_of_outputs = function.signature().output_arg_size();
252   StringCollection outputs;
253   outputs.reserve(number_of_outputs);
254 
255   for (int output_idx = 0; output_idx < number_of_outputs; output_idx++)
256     outputs.push_back(GetOutputNode(function, output_idx));
257   return outputs;
258 }
259 
CreateFalsePredicate(const protobuf::RepeatedPtrField<OpDef_ArgDef> & fake_args,FunctionDefLibrary * library)260 FunctionDef* CreateFalsePredicate(
261     const protobuf::RepeatedPtrField<OpDef_ArgDef>& fake_args,
262     FunctionDefLibrary* library) {
263   GraphDef graph;
264   MutableGraphView graph_view(&graph);
265   auto* node = graph_utils::AddScalarConstNode(false, &graph_view);
266   auto* false_predicate = library->add_function();
267   graph_utils::SetUniqueGraphFunctionName("false_predicate", library,
268                                           false_predicate);
269 
270   int num = 0;
271   for (const auto& fake_arg : fake_args) {
272     auto* arg = false_predicate->mutable_signature()->add_input_arg();
273     arg->set_type(fake_arg.type());
274     arg->set_name(strings::StrCat("fake_arg", num));
275     num++;
276   }
277 
278   auto* output = false_predicate->mutable_signature()->add_output_arg();
279   output->set_name("false_out");
280   output->set_type(DT_BOOL);
281 
282   (*false_predicate->mutable_ret())["false_out"] = node->name() + ":output:0";
283   *false_predicate->mutable_node_def() = std::move(*graph.mutable_node());
284   return false_predicate;
285 }
286 
CheckIfCanCompose(const OpDef & first_signature,const OpDef & second_signature)287 void CheckIfCanCompose(const OpDef& first_signature,
288                        const OpDef& second_signature) {
289   CHECK(CanCompose(first_signature, second_signature))
290       << "The number of input arguments of function " << second_signature.name()
291       << " should be the same as the number of output arguments of function "
292       << first_signature.name() << ".";
293 }
294 
295 }  // namespace
296 
MergeNodes(const FunctionDef & first_function,const FunctionDef & second_function,FunctionDef * fused_function,FunctionDefLibrary * library)297 void MergeNodes(const FunctionDef& first_function,
298                 const FunctionDef& second_function, FunctionDef* fused_function,
299                 FunctionDefLibrary* library) {
300   // Copy all nodes from first_function.
301   fused_function->mutable_node_def()->CopyFrom(first_function.node_def());
302   // Copy transformed nodes from the second function.
303   fused_function->mutable_node_def()->MergeFrom(second_function.node_def());
304 }
305 
CanCompose(const OpDef & first_signature,const OpDef & second_signature)306 bool CanCompose(const OpDef& first_signature, const OpDef& second_signature) {
307   // TODO(prazek): Functions can have additional inputs being placeholders
308   // for a values used in function.  We should be able to also fuse these
309   // functions.
310   return first_signature.output_arg_size() == second_signature.input_arg_size();
311 }
312 
ComposeInput(const StringCollection & first_inputs,const StringCollection & second_inputs,const StringCollection & first_outputs,int arg_num)313 string ComposeInput(const StringCollection& first_inputs,
314                     const StringCollection& second_inputs,
315                     const StringCollection& first_outputs, int arg_num) {
316   // Take corresponding parent output.
317   return first_outputs.at(arg_num);
318 }
319 
ComposeSignature(const OpDef & first_signature,const OpDef & second_signature,OpDef * fused_signature)320 void ComposeSignature(const OpDef& first_signature,
321                       const OpDef& second_signature, OpDef* fused_signature) {
322   CheckIfCanCompose(first_signature, second_signature);
323 
324   // Copy input signature from parent function.
325   *fused_signature->mutable_input_arg() = first_signature.input_arg();
326   // Copy output signature from second function.
327   *fused_signature->mutable_output_arg() = second_signature.output_arg();
328 }
329 
ComposeOutput(const protobuf::Map<string,string> & first_ret,const protobuf::Map<string,string> & second_ret,protobuf::Map<string,string> * fused_ret)330 void ComposeOutput(const protobuf::Map<string, string>& first_ret,
331                    const protobuf::Map<string, string>& second_ret,
332                    protobuf::Map<string, string>* fused_ret) {
333   *fused_ret = second_ret;
334 }
335 
CombineSignature(const OpDef & first_signature,const OpDef & second_signature,OpDef * fused_signature)336 void CombineSignature(const OpDef& first_signature,
337                       const OpDef& second_signature, OpDef* fused_signature) {
338   CheckIfCanCompose(first_signature, second_signature);
339   // Copy input and output signature from parent function.
340   *fused_signature = first_signature;
341 
342   // Add new output parameter.
343   fused_signature->mutable_output_arg()->MergeFrom(
344       second_signature.output_arg());
345 }
346 
CombineOutput(const protobuf::Map<string,string> & first_ret,const protobuf::Map<string,string> & second_ret,protobuf::Map<string,string> * fused_ret)347 void CombineOutput(const protobuf::Map<string, string>& first_ret,
348                    const protobuf::Map<string, string>& second_ret,
349                    protobuf::Map<string, string>* fused_ret) {
350   *fused_ret = first_ret;
351   fused_ret->insert(second_ret.begin(), second_ret.end());
352 }
353 
SameInput(const StringCollection & first_inputs,const StringCollection & second_inputs,const StringCollection & first_outputs,int arg_num)354 string SameInput(const StringCollection& first_inputs,
355                  const StringCollection& second_inputs,
356                  const StringCollection& first_outputs, int arg_num) {
357   return first_inputs.at(arg_num);
358 }
359 
HasSameSignature(const OpDef & first_signature,const OpDef & second_signature)360 bool HasSameSignature(const OpDef& first_signature,
361                       const OpDef& second_signature) {
362   return first_signature.input_arg_size() ==
363              second_signature.input_arg_size() &&
364          first_signature.output_arg_size() ==
365              second_signature.output_arg_size();
366 }
367 
SameSignature(const OpDef & first_signature,const OpDef & second_signature,OpDef * fused_signature)368 void SameSignature(const OpDef& first_signature, const OpDef& second_signature,
369                    OpDef* fused_signature) {
370   CHECK(HasSameSignature(first_signature, second_signature))
371       << "Functions do not have the same signature";
372   // Copy signature from first function.
373   *fused_signature = first_signature;
374 }
375 
LazyConjunctionNodes(const FunctionDef & first_function,const FunctionDef & second_function,FunctionDef * fused_function,FunctionDefLibrary * library)376 void LazyConjunctionNodes(const FunctionDef& first_function,
377                           const FunctionDef& second_function,
378                           FunctionDef* fused_function,
379                           FunctionDefLibrary* library) {
380   fused_function->mutable_node_def()->CopyFrom(first_function.node_def());
381 
382   NodeDefBuilder if_builder("", "If");
383   if_builder.Input(GetOutputNode(first_function, 0), 0, DT_BOOL);
384   DataTypeVector in_arg_types;
385   std::vector<NodeDefBuilder::NodeOut> inputs;
386   for (const auto& input_arg : first_function.signature().input_arg()) {
387     inputs.push_back({input_arg.name(), 0, input_arg.type()});
388     in_arg_types.push_back(input_arg.type());
389   }
390   if_builder.Attr("Tin", in_arg_types);
391 
392   if_builder.Attr("Tcond", DT_BOOL);
393   if_builder.Attr("Tout", DataTypeVector{DT_BOOL});
394   if_builder.Attr("_lower_using_switch_merge", true);
395 
396   NameAttrList then_branch;
397   then_branch.set_name(second_function.signature().name());
398   if_builder.Attr("then_branch", then_branch);
399 
400   auto* false_predicate =
401       CreateFalsePredicate(first_function.signature().input_arg(), library);
402 
403   NameAttrList else_branch;
404   else_branch.set_name(false_predicate->signature().name());
405   if_builder.Attr("else_branch", else_branch);
406   if_builder.Input(inputs);
407 
408   auto* if_node = fused_function->add_node_def();
409   // This is guaranteed to succeed.
410   TF_CHECK_OK(if_builder.Finalize(if_node));
411   function_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
412 
413   GetMutableOutputNode(fused_function, 0) = if_node->name() + ":output:0";
414 }
415 
LazyConjunctionOutput(const protobuf::Map<string,string> & first_ret,const protobuf::Map<string,string> & second_ret,protobuf::Map<string,string> * fused_ret)416 void LazyConjunctionOutput(const protobuf::Map<string, string>& first_ret,
417                            const protobuf::Map<string, string>& second_ret,
418                            protobuf::Map<string, string>* fused_ret) {
419   CHECK_EQ(first_ret.size(), 1);
420   CHECK_EQ(second_ret.size(), 1);
421   // Temporarily copy returns from first_ret.  We are going to change the
422   // output node after creating it.
423   *fused_ret = first_ret;
424 }
425 
// Fuses `first_function` and `second_function` into a new function added to
// `library` (uniquely named with `fused_name_prefix`).  The callbacks decide
// how the fused signature, argument wiring, returns, and nodes are built.
// Returns nullptr, without modifying `library`, if either function has
// attributes (not supported yet).  Otherwise returns the fused function.
FunctionDef* FuseFunctions(
    const FunctionDef& first_function, const FunctionDef& second_function,
    StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature,
    const SetInputFn& set_input, const SetOutputFn& set_output,
    const SetNodesFn& set_nodes, FunctionDefLibrary* library) {
  const bool has_attrs =
      first_function.attr_size() != 0 || second_function.attr_size() != 0;
  if (has_attrs) return nullptr;

  // Private copy of the second function whose argument names are made
  // disjoint from the first function's.
  FunctionDef second_copy = second_function;
  *second_copy.mutable_signature() = GetUniqueSignature(
      first_function.signature(), second_copy.signature(),
      second_copy.mutable_ret(), second_copy.mutable_node_def());

  FunctionDef* fused = library->add_function();

  set_signature(first_function.signature(), second_copy.signature(),
                fused->mutable_signature());

  graph_utils::SetUniqueGraphFunctionName(fused_name_prefix, library, fused);

  // Node names must be disjoint from the first function's before the returns
  // are merged, since returns reference nodes by name.
  RenameFunctionNodes(first_function, second_copy.mutable_node_def(),
                      second_copy.mutable_ret());
  set_output(first_function.ret(), second_copy.ret(), fused->mutable_ret());

  const int output_arg_count = fused->signature().output_arg_size();
  const int ret_count = fused->ret_size();
  CHECK(output_arg_count == ret_count)
      << "Fused function must have the same number of returns as output "
         "args.  Output size: "
      << output_arg_count << ", ret size: " << ret_count;

  const StringCollection first_inputs = GetFunctionInputs(first_function);
  const StringCollection second_inputs = GetFunctionInputs(second_copy);
  const StringCollection first_outputs = GetFunctionOutputs(first_function);
  FuseFunctionNodes(first_inputs, second_inputs, first_outputs, set_input,
                    second_copy.mutable_node_def());
  FuseReturns(first_inputs, second_inputs, first_outputs, set_input,
              fused->mutable_ret());

  set_nodes(first_function, second_copy, fused, library);

  return fused;
}
473 
474 }  // namespace fusion_utils
475 }  // namespace grappler
476 }  // namespace tensorflow
477