1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/data/fusion_utils.h"
17
18 #include "tensorflow/core/framework/dataset.h"
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/op_def.pb.h"
22 #include "tensorflow/core/grappler/grappler_item.h"
23 #include "tensorflow/core/grappler/mutable_graph_view.h"
24 #include "tensorflow/core/grappler/op_types.h"
25 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
26 #include "tensorflow/core/grappler/optimizers/data/function_utils.h"
27 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
28 #include "tensorflow/core/grappler/utils.h"
29 #include "tensorflow/core/lib/gtl/flatmap.h"
30 #include "tensorflow/core/lib/gtl/flatset.h"
31 #include "tensorflow/core/lib/gtl/map_util.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/platform/protobuf.h"
34
35 namespace tensorflow {
36 namespace grappler {
37 namespace fusion_utils {
38
39 namespace {
// Returns the node part of a connection string such as "node:output:0", i.e.
// everything that precedes the first colon. A string that contains no colon is
// returned unchanged.
string ParseNodeConnection(const string& name) {
  const auto colon_pos = name.find(':');
  // `substr(0, npos)` yields the whole string, which covers the no-colon case.
  return name.substr(0, colon_pos);
}
45
// Returns the output part of a connection string such as "node:output:0": the
// suffix that starts at (and includes) the first colon. Returns an empty
// string when `name` contains no colon.
string ParseOutputNode(const string& name) {
  const auto colon_pos = name.find(':');
  if (colon_pos == string::npos) return string();
  return name.substr(colon_pos);
}
50
GetOutputNode(const FunctionDef & function,int output_idx)51 string GetOutputNode(const FunctionDef& function, int output_idx) {
52 const auto& ret_output_name =
53 function.signature().output_arg(output_idx).name();
54 return function.ret().at(ret_output_name);
55 }
56
GetMutableOutputNode(FunctionDef * function,int output_idx)57 string& GetMutableOutputNode(FunctionDef* function, int output_idx) {
58 const auto& ret_output_name =
59 function->signature().output_arg(output_idx).name();
60 return function->mutable_ret()->at(ret_output_name);
61 }
62
63 template <typename Iterable>
GetNames(const Iterable & iterable,int allocate_size)64 StringCollection GetNames(const Iterable& iterable, int allocate_size) {
65 StringCollection names;
66 names.reserve(allocate_size);
67 for (auto& arg : iterable) names.push_back(arg.name());
68 return names;
69 }
70
71 template <typename Iterable>
GetNodeNamesSet(const Iterable & nodes)72 gtl::FlatSet<string> GetNodeNamesSet(const Iterable& nodes) {
73 // NOTE(prazek): Cases where the set is not modified after construction
74 // could use sorted vector with binary_search instead, to make it faster.
75 gtl::FlatSet<string> names;
76 for (const auto& node : nodes) {
77 CHECK(gtl::InsertIfNotPresent(&names, node.name()))
78 << "Functions should have unique node names. Node with name "
79 << node.name() << " already exists";
80 }
81 return names;
82 }
83
84 template <typename Iterable>
GetUniqueNames(const Iterable & first_iterable,const Iterable & second_iterable)85 gtl::FlatMap<string, string> GetUniqueNames(const Iterable& first_iterable,
86 const Iterable& second_iterable) {
87 gtl::FlatMap<string, string> changed_node_names;
88 const auto first_names = GetNodeNamesSet(first_iterable);
89 auto second_names = GetNodeNamesSet(first_iterable);
90 int id = second_iterable.size();
91
92 for (const auto& node : second_iterable) {
93 string name_before = node.name();
94 string name = name_before;
95 bool changed_name = false;
96
97 while (first_names.count(name) ||
98 (changed_name && second_names.count(name))) {
99 name = strings::StrCat(name_before, "/_", id);
100 changed_name = true;
101 ++id;
102 }
103 if (changed_name) {
104 changed_node_names[name_before] = name;
105 // We don't want to pick a new name that would collide with another new
106 // name.
107 second_names.insert(std::move(name));
108 }
109 }
110 return changed_node_names;
111 }
112
113 // We need to rename them and the connections of the inputs that refer to them.
114 // Nodes that will be added to the function can have the same name as the nodes
115 // from parent function.
RenameFunctionNodes(const FunctionDef & first_function,protobuf::RepeatedPtrField<NodeDef> * nodes_to_fuse,protobuf::Map<string,string> * rets_to_fuse)116 void RenameFunctionNodes(const FunctionDef& first_function,
117 protobuf::RepeatedPtrField<NodeDef>* nodes_to_fuse,
118 protobuf::Map<string, string>* rets_to_fuse) {
119 const gtl::FlatMap<string, string> changed_node_names =
120 GetUniqueNames(first_function.node_def(), *nodes_to_fuse);
121
122 auto update_name = [&changed_node_names](string* input) {
123 string input_node = ParseNodeConnection(*input);
124 auto iter = changed_node_names.find(input_node);
125 if (iter != changed_node_names.end()) {
126 *input = iter->second + ParseOutputNode(*input);
127 }
128 };
129
130 for (NodeDef& function_node : *nodes_to_fuse) {
131 if (const string* new_name =
132 gtl::FindOrNull(changed_node_names, function_node.name())) {
133 function_node.set_name(*new_name);
134 }
135
136 for (string& input : *function_node.mutable_input()) {
137 update_name(&input);
138 }
139 }
140
141 for (auto& ret : *rets_to_fuse) update_name(&ret.second);
142 }
143
GetFunctionInputs(const FunctionDef & function)144 StringCollection GetFunctionInputs(const FunctionDef& function) {
145 return GetNames(function.signature().input_arg(),
146 function.signature().input_arg_size());
147 }
148
149 // This function produces signature having names that do not conflict with
150 // `first_signature`. The input of returns and nodes that will be fused are
151 // updated to use new names.
GetUniqueSignature(const OpDef & first_signature,const OpDef & second_signature,protobuf::Map<string,string> * rets_to_fuse,protobuf::RepeatedPtrField<NodeDef> * nodes_to_fuse)152 OpDef GetUniqueSignature(const OpDef& first_signature,
153 const OpDef& second_signature,
154 protobuf::Map<string, string>* rets_to_fuse,
155 protobuf::RepeatedPtrField<NodeDef>* nodes_to_fuse) {
156 const gtl::FlatMap<string, string> changed_input_names =
157 GetUniqueNames(first_signature.input_arg(), second_signature.input_arg());
158 OpDef signature;
159 signature.set_name(second_signature.name());
160
161 for (const auto& input_arg : second_signature.input_arg()) {
162 auto& input = *signature.add_input_arg();
163 input = input_arg;
164 if (const string* new_name =
165 gtl::FindOrNull(changed_input_names, input.name())) {
166 input.set_name(*new_name);
167 }
168 }
169 const gtl::FlatMap<string, string> changed_output_names = GetUniqueNames(
170 first_signature.output_arg(), second_signature.output_arg());
171
172 for (const auto& output_arg : second_signature.output_arg()) {
173 auto& output = *signature.add_output_arg();
174 output = output_arg;
175 if (const string* new_name =
176 gtl::FindOrNull(changed_output_names, output.name())) {
177 output.set_name(*new_name);
178 }
179 }
180
181 protobuf::Map<string, string> new_rets;
182 for (const auto& ret : *rets_to_fuse) {
183 const auto& key = changed_output_names.count(ret.first)
184 ? changed_output_names.at(ret.first)
185 : ret.first;
186 const auto& input = ParseNodeConnection(ret.second);
187 const auto& value =
188 changed_input_names.count(input)
189 ? changed_input_names.at(input) + ParseOutputNode(ret.second)
190 : ret.second;
191 new_rets[key] = value;
192 }
193 *rets_to_fuse = std::move(new_rets);
194
195 for (NodeDef& function_node : *nodes_to_fuse) {
196 for (auto& node_input : *function_node.mutable_input()) {
197 const auto& input = ParseNodeConnection(node_input);
198 if (const string* new_name =
199 gtl::FindOrNull(changed_input_names, input)) {
200 node_input = *new_name + ParseOutputNode(node_input);
201 }
202 }
203 }
204
205 return signature;
206 }
207
208 // This function adds new nodes and changes their input to the output nodes
209 // of parent function. It assumes that the name of nodes to fuse are not
210 // conflicting.
FuseFunctionNodes(const StringCollection & first_inputs,const StringCollection & second_inputs,const StringCollection & first_outputs,const SetInputFn & set_input,protobuf::RepeatedPtrField<NodeDef> * nodes_to_fuse)211 void FuseFunctionNodes(const StringCollection& first_inputs,
212 const StringCollection& second_inputs,
213 const StringCollection& first_outputs,
214 const SetInputFn& set_input,
215 protobuf::RepeatedPtrField<NodeDef>* nodes_to_fuse) {
216 for (NodeDef& function_node : *nodes_to_fuse) {
217 for (auto& node_input : *function_node.mutable_input()) {
218 auto parsed_name = ParseNodeConnection(node_input);
219
220 auto input_it =
221 std::find(second_inputs.begin(), second_inputs.end(), parsed_name);
222 if (input_it == second_inputs.end()) continue;
223
224 auto arg_num = std::distance(second_inputs.begin(), input_it);
225 node_input =
226 set_input(first_inputs, second_inputs, first_outputs, arg_num);
227 }
228 }
229 }
230
231 // This function looks for direct edges from input to return and rewrites
232 // them to the corresponding input of the return of `first_function`.
FuseReturns(const StringCollection & first_inputs,const StringCollection & second_inputs,const StringCollection & first_outputs,const SetInputFn & set_input,protobuf::Map<string,string> * fused_ret)233 void FuseReturns(const StringCollection& first_inputs,
234 const StringCollection& second_inputs,
235 const StringCollection& first_outputs,
236 const SetInputFn& set_input,
237 protobuf::Map<string, string>* fused_ret) {
238 for (auto& ret : *fused_ret) {
239 auto return_input = ParseNodeConnection(ret.second);
240 auto input_it =
241 std::find(second_inputs.begin(), second_inputs.end(), return_input);
242 if (input_it == second_inputs.end()) continue;
243
244 auto input_idx = std::distance(second_inputs.begin(), input_it);
245 ret.second =
246 set_input(first_inputs, second_inputs, first_outputs, input_idx);
247 }
248 }
249
250 // Returns collection of node names that are used as a return from function.
GetFunctionOutputs(const FunctionDef & function)251 StringCollection GetFunctionOutputs(const FunctionDef& function) {
252 const auto number_of_outputs = function.signature().output_arg_size();
253 StringCollection outputs;
254 outputs.reserve(number_of_outputs);
255
256 for (int output_idx = 0; output_idx < number_of_outputs; output_idx++)
257 outputs.push_back(GetOutputNode(function, output_idx));
258 return outputs;
259 }
260
CreateFalsePredicate(const protobuf::RepeatedPtrField<OpDef_ArgDef> & fake_args,FunctionDefLibrary * library)261 FunctionDef* CreateFalsePredicate(
262 const protobuf::RepeatedPtrField<OpDef_ArgDef>& fake_args,
263 FunctionDefLibrary* library) {
264 GraphDef graph;
265 MutableGraphView graph_view(&graph);
266 auto* node = graph_utils::AddScalarConstNode(false, &graph_view);
267 auto* false_predicate = library->add_function();
268 graph_utils::SetUniqueGraphFunctionName("false_predicate", library,
269 false_predicate);
270
271 int num = 0;
272 for (const auto& fake_arg : fake_args) {
273 auto* arg = false_predicate->mutable_signature()->add_input_arg();
274 arg->set_type(fake_arg.type());
275 arg->set_name(strings::StrCat("fake_arg", num));
276 num++;
277 }
278
279 auto* output = false_predicate->mutable_signature()->add_output_arg();
280 output->set_name("false_out");
281 output->set_type(DT_BOOL);
282
283 (*false_predicate->mutable_ret())["false_out"] = node->name() + ":output:0";
284 *false_predicate->mutable_node_def() = std::move(*graph.mutable_node());
285 return false_predicate;
286 }
287
CheckIfCanCompose(const OpDef & first_signature,const OpDef & second_signature)288 void CheckIfCanCompose(const OpDef& first_signature,
289 const OpDef& second_signature) {
290 CHECK(CanCompose(first_signature, second_signature))
291 << "The number of input arguments of function " << second_signature.name()
292 << " should be the same as the number of output arguments of function "
293 << first_signature.name() << ".";
294 }
295
296 } // namespace
297
MergeNodes(const FunctionDef & first_function,const FunctionDef & second_function,FunctionDef * fused_function,FunctionDefLibrary * library)298 void MergeNodes(const FunctionDef& first_function,
299 const FunctionDef& second_function, FunctionDef* fused_function,
300 FunctionDefLibrary* library) {
301 // Copy all nodes from first_function.
302 fused_function->mutable_node_def()->CopyFrom(first_function.node_def());
303 // Copy transformed nodes from the second function.
304 fused_function->mutable_node_def()->MergeFrom(second_function.node_def());
305 }
306
CanCompose(const OpDef & first_signature,const OpDef & second_signature)307 bool CanCompose(const OpDef& first_signature, const OpDef& second_signature) {
308 // TODO(prazek): Functions can have additional inputs being placeholders
309 // for a values used in function. We should be able to also fuse these
310 // functions.
311 return first_signature.output_arg_size() == second_signature.input_arg_size();
312 }
313
ComposeInput(const StringCollection & first_inputs,const StringCollection & second_inputs,const StringCollection & first_outputs,int arg_num)314 string ComposeInput(const StringCollection& first_inputs,
315 const StringCollection& second_inputs,
316 const StringCollection& first_outputs, int arg_num) {
317 // Take corresponding parent output.
318 return first_outputs.at(arg_num);
319 }
320
ComposeSignature(const OpDef & first_signature,const OpDef & second_signature,OpDef * fused_signature)321 void ComposeSignature(const OpDef& first_signature,
322 const OpDef& second_signature, OpDef* fused_signature) {
323 CheckIfCanCompose(first_signature, second_signature);
324
325 // Copy input signature from parent function.
326 *fused_signature->mutable_input_arg() = first_signature.input_arg();
327 // Copy output signature from second function.
328 *fused_signature->mutable_output_arg() = second_signature.output_arg();
329 }
330
ComposeOutput(const protobuf::Map<string,string> & first_ret,const protobuf::Map<string,string> & second_ret,protobuf::Map<string,string> * fused_ret)331 void ComposeOutput(const protobuf::Map<string, string>& first_ret,
332 const protobuf::Map<string, string>& second_ret,
333 protobuf::Map<string, string>* fused_ret) {
334 *fused_ret = second_ret;
335 }
336
CombineSignature(const OpDef & first_signature,const OpDef & second_signature,OpDef * fused_signature)337 void CombineSignature(const OpDef& first_signature,
338 const OpDef& second_signature, OpDef* fused_signature) {
339 CheckIfCanCompose(first_signature, second_signature);
340 // Copy input and output signature from parent function.
341 *fused_signature = first_signature;
342
343 // Add new output parameter.
344 fused_signature->mutable_output_arg()->MergeFrom(
345 second_signature.output_arg());
346 }
347
CombineOutput(const protobuf::Map<string,string> & first_ret,const protobuf::Map<string,string> & second_ret,protobuf::Map<string,string> * fused_ret)348 void CombineOutput(const protobuf::Map<string, string>& first_ret,
349 const protobuf::Map<string, string>& second_ret,
350 protobuf::Map<string, string>* fused_ret) {
351 *fused_ret = first_ret;
352 fused_ret->insert(second_ret.begin(), second_ret.end());
353 }
354
SameInput(const StringCollection & first_inputs,const StringCollection & second_inputs,const StringCollection & first_outputs,int arg_num)355 string SameInput(const StringCollection& first_inputs,
356 const StringCollection& second_inputs,
357 const StringCollection& first_outputs, int arg_num) {
358 return first_inputs.at(arg_num);
359 }
360
HasSameSignature(const OpDef & first_signature,const OpDef & second_signature)361 bool HasSameSignature(const OpDef& first_signature,
362 const OpDef& second_signature) {
363 return first_signature.input_arg_size() ==
364 second_signature.input_arg_size() &&
365 first_signature.output_arg_size() ==
366 second_signature.output_arg_size();
367 }
368
SameSignature(const OpDef & first_signature,const OpDef & second_signature,OpDef * fused_signature)369 void SameSignature(const OpDef& first_signature, const OpDef& second_signature,
370 OpDef* fused_signature) {
371 CHECK(HasSameSignature(first_signature, second_signature))
372 << "Functions do not have the same signature";
373 // Copy signature from first function.
374 *fused_signature = first_signature;
375 }
376
LazyConjunctionNodes(const FunctionDef & first_function,const FunctionDef & second_function,FunctionDef * fused_function,FunctionDefLibrary * library)377 void LazyConjunctionNodes(const FunctionDef& first_function,
378 const FunctionDef& second_function,
379 FunctionDef* fused_function,
380 FunctionDefLibrary* library) {
381 fused_function->mutable_node_def()->CopyFrom(first_function.node_def());
382
383 NodeDefBuilder if_builder("", "If");
384 if_builder.Input(GetOutputNode(first_function, 0), 0, DT_BOOL);
385 DataTypeVector in_arg_types;
386 std::vector<NodeDefBuilder::NodeOut> inputs;
387 for (const auto& input_arg : first_function.signature().input_arg()) {
388 inputs.push_back({input_arg.name(), 0, input_arg.type()});
389 in_arg_types.push_back(input_arg.type());
390 }
391 if_builder.Attr("Tin", in_arg_types);
392
393 if_builder.Attr("Tcond", DT_BOOL);
394 if_builder.Attr("Tout", DataTypeVector{DT_BOOL});
395 if_builder.Attr("_lower_using_switch_merge", true);
396
397 NameAttrList then_branch;
398 then_branch.set_name(second_function.signature().name());
399 if_builder.Attr("then_branch", then_branch);
400
401 auto* false_predicate =
402 CreateFalsePredicate(first_function.signature().input_arg(), library);
403
404 NameAttrList else_branch;
405 else_branch.set_name(false_predicate->signature().name());
406 if_builder.Attr("else_branch", else_branch);
407 if_builder.Input(inputs);
408
409 auto* if_node = fused_function->add_node_def();
410 // This is guaranteed to succeed.
411 TF_CHECK_OK(if_builder.Finalize(if_node));
412 function_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
413
414 GetMutableOutputNode(fused_function, 0) = if_node->name() + ":output:0";
415 }
416
LazyConjunctionOutput(const protobuf::Map<string,string> & first_ret,const protobuf::Map<string,string> & second_ret,protobuf::Map<string,string> * fused_ret)417 void LazyConjunctionOutput(const protobuf::Map<string, string>& first_ret,
418 const protobuf::Map<string, string>& second_ret,
419 protobuf::Map<string, string>* fused_ret) {
420 CHECK_EQ(first_ret.size(), 1);
421 CHECK_EQ(second_ret.size(), 1);
422 // Temporarily copy returns from first_ret. We are going to change the
423 // output node after creating it.
424 *fused_ret = first_ret;
425 }
426
// Fuses `first_function` and `second_function` into a new function that is
// added to `library`, and returns a pointer to it.
//
// The way the two functions are combined is supplied by the caller:
//   * `set_signature` fills in the signature of the fused function,
//   * `set_input` chooses what replaces each input argument of the second
//     function,
//   * `set_output` fills in the returns of the fused function,
//   * `set_nodes` fills in the body of the fused function.
// `fused_name_prefix` is the prefix from which the unique name of the fused
// function is derived.
//
// Returns nullptr, without touching `library`, if either function carries an
// attribute other than the tf.data function marker and
// `_construction_context`. CHECK-fails if the fused function ends up with a
// number of returns different from its number of output arguments.
FunctionDef* FuseFunctions(
    const FunctionDef& first_function, const FunctionDef& second_function,
    StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature,
    const SetInputFn& set_input, const SetOutputFn& set_output,
    const SetNodesFn& set_nodes, FunctionDefLibrary* library) {
  // True if `func` has more attributes than the ones this pass knows how to
  // carry over.
  auto has_unknown_attrs = [](const FunctionDef& func) {
    int known_attribute_size = 0;

    if (data::IsTFDataFunction(func)) known_attribute_size += 1;
    if (func.attr().contains("_construction_context"))
      known_attribute_size += 1;

    return func.attr_size() > known_attribute_size;
  };
  if (has_unknown_attrs(first_function) || has_unknown_attrs(second_function)) {
    return nullptr;  // Functions with attributes are currently not supported.
  }

  // This function will be used as a clone of second function, having unique
  // names.
  FunctionDef setup_function = second_function;
  *setup_function.mutable_signature() = GetUniqueSignature(
      first_function.signature(), setup_function.signature(),
      setup_function.mutable_ret(), setup_function.mutable_node_def());

  FunctionDef* fused_function = library->add_function();

  set_signature(first_function.signature(), setup_function.signature(),
                fused_function->mutable_signature());

  graph_utils::SetUniqueGraphFunctionName(fused_name_prefix, library,
                                          fused_function);

  // Make the node names of the clone unique with respect to the first
  // function before its returns are handed to `set_output`.
  RenameFunctionNodes(first_function, setup_function.mutable_node_def(),
                      setup_function.mutable_ret());
  set_output(first_function.ret(), setup_function.ret(),
             fused_function->mutable_ret());

  CHECK(fused_function->signature().output_arg_size() ==
        fused_function->ret_size())
      << "Fused function must have the same number of returns as output "
         "args.  Output size: "
      << fused_function->signature().output_arg_size()
      << ", ret size: " << fused_function->ret_size();

  // Replace references to the second function's input arguments, both in its
  // nodes and in the fused returns.
  const auto first_inputs = GetFunctionInputs(first_function);
  const auto second_inputs = GetFunctionInputs(setup_function);
  const auto first_outputs = GetFunctionOutputs(first_function);
  FuseFunctionNodes(first_inputs, second_inputs, first_outputs, set_input,
                    setup_function.mutable_node_def());
  FuseReturns(first_inputs, second_inputs, first_outputs, set_input,
              fused_function->mutable_ret());

  set_nodes(first_function, setup_function, fused_function, library);
  (*fused_function->mutable_attr())[data::kTFDataFunction].set_b(true);

  // Preserve `_construction_context` attribute in the fused function.
  auto get_construction_context = [](const FunctionDef& func) {
    auto iter = func.attr().find("_construction_context");
    if (iter == func.attr().cend()) return std::string();
    return iter->second.s();
  };
  std::string first_construction_context =
      get_construction_context(first_function);
  std::string second_construction_context =
      get_construction_context(second_function);
  if (first_construction_context != second_construction_context) {
    // Report both values; the fused function keeps the first one.
    LOG(ERROR) << "_construction_context attribute mismatch during fused "
                  "function optimization pass. First function: "
               << first_construction_context
               << " Second function: " << second_construction_context;
  }
  if (!first_construction_context.empty()) {
    (*fused_function->mutable_attr())["_construction_context"].set_s(
        first_construction_context);
  }

  return fused_function;
}
506
507 } // namespace fusion_utils
508 } // namespace grappler
509 } // namespace tensorflow
510