1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/fusion_utils.h"
17 
18 #include "tensorflow/core/framework/dataset.h"
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/op_def.pb.h"
22 #include "tensorflow/core/grappler/grappler_item.h"
23 #include "tensorflow/core/grappler/mutable_graph_view.h"
24 #include "tensorflow/core/grappler/op_types.h"
25 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
26 #include "tensorflow/core/grappler/optimizers/data/function_utils.h"
27 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
28 #include "tensorflow/core/grappler/utils.h"
29 #include "tensorflow/core/lib/gtl/flatmap.h"
30 #include "tensorflow/core/lib/gtl/flatset.h"
31 #include "tensorflow/core/lib/gtl/map_util.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/platform/protobuf.h"
34 
35 namespace tensorflow {
36 namespace grappler {
37 namespace fusion_utils {
38 
39 namespace {
// Returns the node part of a connection string: everything before the first
// ':' (e.g. "node" for "node:output:0"), or the whole string when it contains
// no ':' at all.
string ParseNodeConnection(const string& name) {
  const string::size_type colon_pos = name.find(':');
  return colon_pos == string::npos ? name : name.substr(0, colon_pos);
}
45 
// Returns the output suffix of a connection string, starting at (and
// including) the first ':' -- e.g. ":output:0" for "node:output:0" -- or an
// empty string when there is no ':'.
string ParseOutputNode(const string& name) {
  const string::size_type colon_pos = name.find(':');
  if (colon_pos == string::npos) {
    return string();
  }
  return name.substr(colon_pos);
}
50 
GetOutputNode(const FunctionDef & function,int output_idx)51 string GetOutputNode(const FunctionDef& function, int output_idx) {
52   const auto& ret_output_name =
53       function.signature().output_arg(output_idx).name();
54   return function.ret().at(ret_output_name);
55 }
56 
GetMutableOutputNode(FunctionDef * function,int output_idx)57 string& GetMutableOutputNode(FunctionDef* function, int output_idx) {
58   const auto& ret_output_name =
59       function->signature().output_arg(output_idx).name();
60   return function->mutable_ret()->at(ret_output_name);
61 }
62 
63 template <typename Iterable>
GetNames(const Iterable & iterable,int allocate_size)64 StringCollection GetNames(const Iterable& iterable, int allocate_size) {
65   StringCollection names;
66   names.reserve(allocate_size);
67   for (auto& arg : iterable) names.push_back(arg.name());
68   return names;
69 }
70 
71 template <typename Iterable>
GetNodeNamesSet(const Iterable & nodes)72 gtl::FlatSet<string> GetNodeNamesSet(const Iterable& nodes) {
73   // NOTE(prazek): Cases where the set is not modified after construction
74   // could use sorted vector with binary_search instead, to make it faster.
75   gtl::FlatSet<string> names;
76   for (const auto& node : nodes) {
77     CHECK(gtl::InsertIfNotPresent(&names, node.name()))
78         << "Functions should have unique node names. Node with name "
79         << node.name() << " already exists";
80   }
81   return names;
82 }
83 
84 template <typename Iterable>
GetUniqueNames(const Iterable & first_iterable,const Iterable & second_iterable)85 gtl::FlatMap<string, string> GetUniqueNames(const Iterable& first_iterable,
86                                             const Iterable& second_iterable) {
87   gtl::FlatMap<string, string> changed_node_names;
88   const auto first_names = GetNodeNamesSet(first_iterable);
89   auto second_names = GetNodeNamesSet(first_iterable);
90   int id = second_iterable.size();
91 
92   for (const auto& node : second_iterable) {
93     string name_before = node.name();
94     string name = name_before;
95     bool changed_name = false;
96 
97     while (first_names.count(name) ||
98            (changed_name && second_names.count(name))) {
99       name = strings::StrCat(name_before, "/_", id);
100       changed_name = true;
101       ++id;
102     }
103     if (changed_name) {
104       changed_node_names[name_before] = name;
105       // We don't want to pick a new name that would collide with another new
106       // name.
107       second_names.insert(std::move(name));
108     }
109   }
110   return changed_node_names;
111 }
112 
113 // We need to rename them and the connections of the inputs that refer to them.
114 // Nodes that will be added to the function can have the same name as the nodes
115 // from parent function.
RenameFunctionNodes(const FunctionDef & first_function,protobuf::RepeatedPtrField<NodeDef> * nodes_to_fuse,protobuf::Map<string,string> * rets_to_fuse)116 void RenameFunctionNodes(const FunctionDef& first_function,
117                          protobuf::RepeatedPtrField<NodeDef>* nodes_to_fuse,
118                          protobuf::Map<string, string>* rets_to_fuse) {
119   const gtl::FlatMap<string, string> changed_node_names =
120       GetUniqueNames(first_function.node_def(), *nodes_to_fuse);
121 
122   auto update_name = [&changed_node_names](string* input) {
123     string input_node = ParseNodeConnection(*input);
124     auto iter = changed_node_names.find(input_node);
125     if (iter != changed_node_names.end()) {
126       *input = iter->second + ParseOutputNode(*input);
127     }
128   };
129 
130   for (NodeDef& function_node : *nodes_to_fuse) {
131     if (const string* new_name =
132             gtl::FindOrNull(changed_node_names, function_node.name())) {
133       function_node.set_name(*new_name);
134     }
135 
136     for (string& input : *function_node.mutable_input()) {
137       update_name(&input);
138     }
139   }
140 
141   for (auto& ret : *rets_to_fuse) update_name(&ret.second);
142 }
143 
GetFunctionInputs(const FunctionDef & function)144 StringCollection GetFunctionInputs(const FunctionDef& function) {
145   return GetNames(function.signature().input_arg(),
146                   function.signature().input_arg_size());
147 }
148 
149 // This function produces signature having names that do not conflict with
150 // `first_signature`.  The input of returns and nodes that will be fused are
151 // updated to use new names.
GetUniqueSignature(const OpDef & first_signature,const OpDef & second_signature,protobuf::Map<string,string> * rets_to_fuse,protobuf::RepeatedPtrField<NodeDef> * nodes_to_fuse)152 OpDef GetUniqueSignature(const OpDef& first_signature,
153                          const OpDef& second_signature,
154                          protobuf::Map<string, string>* rets_to_fuse,
155                          protobuf::RepeatedPtrField<NodeDef>* nodes_to_fuse) {
156   const gtl::FlatMap<string, string> changed_input_names =
157       GetUniqueNames(first_signature.input_arg(), second_signature.input_arg());
158   OpDef signature;
159   signature.set_name(second_signature.name());
160 
161   for (const auto& input_arg : second_signature.input_arg()) {
162     auto& input = *signature.add_input_arg();
163     input = input_arg;
164     if (const string* new_name =
165             gtl::FindOrNull(changed_input_names, input.name())) {
166       input.set_name(*new_name);
167     }
168   }
169   const gtl::FlatMap<string, string> changed_output_names = GetUniqueNames(
170       first_signature.output_arg(), second_signature.output_arg());
171 
172   for (const auto& output_arg : second_signature.output_arg()) {
173     auto& output = *signature.add_output_arg();
174     output = output_arg;
175     if (const string* new_name =
176             gtl::FindOrNull(changed_output_names, output.name())) {
177       output.set_name(*new_name);
178     }
179   }
180 
181   protobuf::Map<string, string> new_rets;
182   for (const auto& ret : *rets_to_fuse) {
183     const auto& key = changed_output_names.count(ret.first)
184                           ? changed_output_names.at(ret.first)
185                           : ret.first;
186     const auto& input = ParseNodeConnection(ret.second);
187     const auto& value =
188         changed_input_names.count(input)
189             ? changed_input_names.at(input) + ParseOutputNode(ret.second)
190             : ret.second;
191     new_rets[key] = value;
192   }
193   *rets_to_fuse = std::move(new_rets);
194 
195   for (NodeDef& function_node : *nodes_to_fuse) {
196     for (auto& node_input : *function_node.mutable_input()) {
197       const auto& input = ParseNodeConnection(node_input);
198       if (const string* new_name =
199               gtl::FindOrNull(changed_input_names, input)) {
200         node_input = *new_name + ParseOutputNode(node_input);
201       }
202     }
203   }
204 
205   return signature;
206 }
207 
208 // This function adds new nodes and changes their input to the output nodes
209 // of parent function.  It assumes that the name of nodes to fuse are not
210 // conflicting.
FuseFunctionNodes(const StringCollection & first_inputs,const StringCollection & second_inputs,const StringCollection & first_outputs,const SetInputFn & set_input,protobuf::RepeatedPtrField<NodeDef> * nodes_to_fuse)211 void FuseFunctionNodes(const StringCollection& first_inputs,
212                        const StringCollection& second_inputs,
213                        const StringCollection& first_outputs,
214                        const SetInputFn& set_input,
215                        protobuf::RepeatedPtrField<NodeDef>* nodes_to_fuse) {
216   for (NodeDef& function_node : *nodes_to_fuse) {
217     for (auto& node_input : *function_node.mutable_input()) {
218       auto parsed_name = ParseNodeConnection(node_input);
219 
220       auto input_it =
221           std::find(second_inputs.begin(), second_inputs.end(), parsed_name);
222       if (input_it == second_inputs.end()) continue;
223 
224       auto arg_num = std::distance(second_inputs.begin(), input_it);
225       node_input =
226           set_input(first_inputs, second_inputs, first_outputs, arg_num);
227     }
228   }
229 }
230 
231 // This function looks for direct edges from input to return and rewrites
232 // them to the corresponding input of the return of `first_function`.
FuseReturns(const StringCollection & first_inputs,const StringCollection & second_inputs,const StringCollection & first_outputs,const SetInputFn & set_input,protobuf::Map<string,string> * fused_ret)233 void FuseReturns(const StringCollection& first_inputs,
234                  const StringCollection& second_inputs,
235                  const StringCollection& first_outputs,
236                  const SetInputFn& set_input,
237                  protobuf::Map<string, string>* fused_ret) {
238   for (auto& ret : *fused_ret) {
239     auto return_input = ParseNodeConnection(ret.second);
240     auto input_it =
241         std::find(second_inputs.begin(), second_inputs.end(), return_input);
242     if (input_it == second_inputs.end()) continue;
243 
244     auto input_idx = std::distance(second_inputs.begin(), input_it);
245     ret.second =
246         set_input(first_inputs, second_inputs, first_outputs, input_idx);
247   }
248 }
249 
250 // Returns collection of node names that are used as a return from function.
GetFunctionOutputs(const FunctionDef & function)251 StringCollection GetFunctionOutputs(const FunctionDef& function) {
252   const auto number_of_outputs = function.signature().output_arg_size();
253   StringCollection outputs;
254   outputs.reserve(number_of_outputs);
255 
256   for (int output_idx = 0; output_idx < number_of_outputs; output_idx++)
257     outputs.push_back(GetOutputNode(function, output_idx));
258   return outputs;
259 }
260 
CreateFalsePredicate(const protobuf::RepeatedPtrField<OpDef_ArgDef> & fake_args,FunctionDefLibrary * library)261 FunctionDef* CreateFalsePredicate(
262     const protobuf::RepeatedPtrField<OpDef_ArgDef>& fake_args,
263     FunctionDefLibrary* library) {
264   GraphDef graph;
265   MutableGraphView graph_view(&graph);
266   auto* node = graph_utils::AddScalarConstNode(false, &graph_view);
267   auto* false_predicate = library->add_function();
268   graph_utils::SetUniqueGraphFunctionName("false_predicate", library,
269                                           false_predicate);
270 
271   int num = 0;
272   for (const auto& fake_arg : fake_args) {
273     auto* arg = false_predicate->mutable_signature()->add_input_arg();
274     arg->set_type(fake_arg.type());
275     arg->set_name(strings::StrCat("fake_arg", num));
276     num++;
277   }
278 
279   auto* output = false_predicate->mutable_signature()->add_output_arg();
280   output->set_name("false_out");
281   output->set_type(DT_BOOL);
282 
283   (*false_predicate->mutable_ret())["false_out"] = node->name() + ":output:0";
284   *false_predicate->mutable_node_def() = std::move(*graph.mutable_node());
285   return false_predicate;
286 }
287 
CheckIfCanCompose(const OpDef & first_signature,const OpDef & second_signature)288 void CheckIfCanCompose(const OpDef& first_signature,
289                        const OpDef& second_signature) {
290   CHECK(CanCompose(first_signature, second_signature))
291       << "The number of input arguments of function " << second_signature.name()
292       << " should be the same as the number of output arguments of function "
293       << first_signature.name() << ".";
294 }
295 
296 }  // namespace
297 
MergeNodes(const FunctionDef & first_function,const FunctionDef & second_function,FunctionDef * fused_function,FunctionDefLibrary * library)298 void MergeNodes(const FunctionDef& first_function,
299                 const FunctionDef& second_function, FunctionDef* fused_function,
300                 FunctionDefLibrary* library) {
301   // Copy all nodes from first_function.
302   fused_function->mutable_node_def()->CopyFrom(first_function.node_def());
303   // Copy transformed nodes from the second function.
304   fused_function->mutable_node_def()->MergeFrom(second_function.node_def());
305 }
306 
CanCompose(const OpDef & first_signature,const OpDef & second_signature)307 bool CanCompose(const OpDef& first_signature, const OpDef& second_signature) {
308   // TODO(prazek): Functions can have additional inputs being placeholders
309   // for a values used in function.  We should be able to also fuse these
310   // functions.
311   return first_signature.output_arg_size() == second_signature.input_arg_size();
312 }
313 
ComposeInput(const StringCollection & first_inputs,const StringCollection & second_inputs,const StringCollection & first_outputs,int arg_num)314 string ComposeInput(const StringCollection& first_inputs,
315                     const StringCollection& second_inputs,
316                     const StringCollection& first_outputs, int arg_num) {
317   // Take corresponding parent output.
318   return first_outputs.at(arg_num);
319 }
320 
ComposeSignature(const OpDef & first_signature,const OpDef & second_signature,OpDef * fused_signature)321 void ComposeSignature(const OpDef& first_signature,
322                       const OpDef& second_signature, OpDef* fused_signature) {
323   CheckIfCanCompose(first_signature, second_signature);
324 
325   // Copy input signature from parent function.
326   *fused_signature->mutable_input_arg() = first_signature.input_arg();
327   // Copy output signature from second function.
328   *fused_signature->mutable_output_arg() = second_signature.output_arg();
329 }
330 
ComposeOutput(const protobuf::Map<string,string> & first_ret,const protobuf::Map<string,string> & second_ret,protobuf::Map<string,string> * fused_ret)331 void ComposeOutput(const protobuf::Map<string, string>& first_ret,
332                    const protobuf::Map<string, string>& second_ret,
333                    protobuf::Map<string, string>* fused_ret) {
334   *fused_ret = second_ret;
335 }
336 
CombineSignature(const OpDef & first_signature,const OpDef & second_signature,OpDef * fused_signature)337 void CombineSignature(const OpDef& first_signature,
338                       const OpDef& second_signature, OpDef* fused_signature) {
339   CheckIfCanCompose(first_signature, second_signature);
340   // Copy input and output signature from parent function.
341   *fused_signature = first_signature;
342 
343   // Add new output parameter.
344   fused_signature->mutable_output_arg()->MergeFrom(
345       second_signature.output_arg());
346 }
347 
CombineOutput(const protobuf::Map<string,string> & first_ret,const protobuf::Map<string,string> & second_ret,protobuf::Map<string,string> * fused_ret)348 void CombineOutput(const protobuf::Map<string, string>& first_ret,
349                    const protobuf::Map<string, string>& second_ret,
350                    protobuf::Map<string, string>* fused_ret) {
351   *fused_ret = first_ret;
352   fused_ret->insert(second_ret.begin(), second_ret.end());
353 }
354 
SameInput(const StringCollection & first_inputs,const StringCollection & second_inputs,const StringCollection & first_outputs,int arg_num)355 string SameInput(const StringCollection& first_inputs,
356                  const StringCollection& second_inputs,
357                  const StringCollection& first_outputs, int arg_num) {
358   return first_inputs.at(arg_num);
359 }
360 
HasSameSignature(const OpDef & first_signature,const OpDef & second_signature)361 bool HasSameSignature(const OpDef& first_signature,
362                       const OpDef& second_signature) {
363   return first_signature.input_arg_size() ==
364              second_signature.input_arg_size() &&
365          first_signature.output_arg_size() ==
366              second_signature.output_arg_size();
367 }
368 
SameSignature(const OpDef & first_signature,const OpDef & second_signature,OpDef * fused_signature)369 void SameSignature(const OpDef& first_signature, const OpDef& second_signature,
370                    OpDef* fused_signature) {
371   CHECK(HasSameSignature(first_signature, second_signature))
372       << "Functions do not have the same signature";
373   // Copy signature from first function.
374   *fused_signature = first_signature;
375 }
376 
LazyConjunctionNodes(const FunctionDef & first_function,const FunctionDef & second_function,FunctionDef * fused_function,FunctionDefLibrary * library)377 void LazyConjunctionNodes(const FunctionDef& first_function,
378                           const FunctionDef& second_function,
379                           FunctionDef* fused_function,
380                           FunctionDefLibrary* library) {
381   fused_function->mutable_node_def()->CopyFrom(first_function.node_def());
382 
383   NodeDefBuilder if_builder("", "If");
384   if_builder.Input(GetOutputNode(first_function, 0), 0, DT_BOOL);
385   DataTypeVector in_arg_types;
386   std::vector<NodeDefBuilder::NodeOut> inputs;
387   for (const auto& input_arg : first_function.signature().input_arg()) {
388     inputs.push_back({input_arg.name(), 0, input_arg.type()});
389     in_arg_types.push_back(input_arg.type());
390   }
391   if_builder.Attr("Tin", in_arg_types);
392 
393   if_builder.Attr("Tcond", DT_BOOL);
394   if_builder.Attr("Tout", DataTypeVector{DT_BOOL});
395   if_builder.Attr("_lower_using_switch_merge", true);
396 
397   NameAttrList then_branch;
398   then_branch.set_name(second_function.signature().name());
399   if_builder.Attr("then_branch", then_branch);
400 
401   auto* false_predicate =
402       CreateFalsePredicate(first_function.signature().input_arg(), library);
403 
404   NameAttrList else_branch;
405   else_branch.set_name(false_predicate->signature().name());
406   if_builder.Attr("else_branch", else_branch);
407   if_builder.Input(inputs);
408 
409   auto* if_node = fused_function->add_node_def();
410   // This is guaranteed to succeed.
411   TF_CHECK_OK(if_builder.Finalize(if_node));
412   function_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
413 
414   GetMutableOutputNode(fused_function, 0) = if_node->name() + ":output:0";
415 }
416 
LazyConjunctionOutput(const protobuf::Map<string,string> & first_ret,const protobuf::Map<string,string> & second_ret,protobuf::Map<string,string> * fused_ret)417 void LazyConjunctionOutput(const protobuf::Map<string, string>& first_ret,
418                            const protobuf::Map<string, string>& second_ret,
419                            protobuf::Map<string, string>* fused_ret) {
420   CHECK_EQ(first_ret.size(), 1);
421   CHECK_EQ(second_ret.size(), 1);
422   // Temporarily copy returns from first_ret.  We are going to change the
423   // output node after creating it.
424   *fused_ret = first_ret;
425 }
426 
FuseFunctions(const FunctionDef & first_function,const FunctionDef & second_function,StringPiece fused_name_prefix,const SetFunctionSignatureFn & set_signature,const SetInputFn & set_input,const SetOutputFn & set_output,const SetNodesFn & set_nodes,FunctionDefLibrary * library)427 FunctionDef* FuseFunctions(
428     const FunctionDef& first_function, const FunctionDef& second_function,
429     StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature,
430     const SetInputFn& set_input, const SetOutputFn& set_output,
431     const SetNodesFn& set_nodes, FunctionDefLibrary* library) {
432   auto has_unknown_attrs = [](const FunctionDef& func) {
433     int known_attribute_size = 0;
434 
435     if (data::IsTFDataFunction(func)) known_attribute_size += 1;
436     if (func.attr().contains("_construction_context"))
437       known_attribute_size += 1;
438 
439     return func.attr_size() > known_attribute_size;
440   };
441   if (has_unknown_attrs(first_function) || has_unknown_attrs(second_function)) {
442     return nullptr;  // Functions with attributes are currently not supported.
443   }
444 
445   // This function will be used as a clone of second function, having unique
446   // names.
447   FunctionDef setup_function = second_function;
448   *setup_function.mutable_signature() = GetUniqueSignature(
449       first_function.signature(), setup_function.signature(),
450       setup_function.mutable_ret(), setup_function.mutable_node_def());
451 
452   FunctionDef* fused_function = library->add_function();
453 
454   set_signature(first_function.signature(), setup_function.signature(),
455                 fused_function->mutable_signature());
456 
457   graph_utils::SetUniqueGraphFunctionName(fused_name_prefix, library,
458                                           fused_function);
459 
460   RenameFunctionNodes(first_function, setup_function.mutable_node_def(),
461                       setup_function.mutable_ret());
462   set_output(first_function.ret(), setup_function.ret(),
463              fused_function->mutable_ret());
464 
465   CHECK(fused_function->signature().output_arg_size() ==
466         fused_function->ret_size())
467       << "Fused function must have the same number of returns as output "
468          "args.  Output size: "
469       << fused_function->signature().output_arg_size()
470       << ", ret size: " << fused_function->ret_size();
471 
472   const auto first_inputs = GetFunctionInputs(first_function);
473   const auto second_inputs = GetFunctionInputs(setup_function);
474   const auto first_outputs = GetFunctionOutputs(first_function);
475   FuseFunctionNodes(first_inputs, second_inputs, first_outputs, set_input,
476                     setup_function.mutable_node_def());
477   FuseReturns(first_inputs, second_inputs, first_outputs, set_input,
478               fused_function->mutable_ret());
479 
480   set_nodes(first_function, setup_function, fused_function, library);
481   (*fused_function->mutable_attr())[data::kTFDataFunction].set_b(true);
482 
483   // Preserve `_construction_context` attribute in the fused function.
484   auto get_construction_context = [](const FunctionDef& func) {
485     auto iter = func.attr().find("_construction_context");
486     if (iter == func.attr().cend()) return std::string();
487     return iter->second.s();
488   };
489   std::string first_construction_context =
490       get_construction_context(first_function);
491   std::string second_construction_context =
492       get_construction_context(second_function);
493   if (first_construction_context != second_construction_context) {
494     LOG(ERROR) << "_construction_context attribute mismatch during fused "
495                   "function optimization pass. First function: "
496                << first_construction_context
497                << " Second function: " << first_construction_context;
498   }
499   if (!first_construction_context.empty()) {
500     (*fused_function->mutable_attr())["_construction_context"].set_s(
501         first_construction_context);
502   }
503 
504   return fused_function;
505 }
506 
507 }  // namespace fusion_utils
508 }  // namespace grappler
509 }  // namespace tensorflow
510