1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/function_optimizer.h"
17
18 #include <vector>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/str_replace.h"
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/core/common_runtime/device.h"
27 #include "tensorflow/core/common_runtime/device_mgr.h"
28 #include "tensorflow/core/common_runtime/device_set.h"
29 #include "tensorflow/core/common_runtime/function.h"
30 #include "tensorflow/core/common_runtime/lower_if_while.h"
31 #include "tensorflow/core/common_runtime/optimization_registry.h"
32 #include "tensorflow/core/common_runtime/placer.h"
33 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
34 #include "tensorflow/core/framework/attr_value_util.h"
35 #include "tensorflow/core/framework/function.h"
36 #include "tensorflow/core/framework/function.pb.h"
37 #include "tensorflow/core/framework/graph_def_util.h"
38 #include "tensorflow/core/framework/node_def.pb.h"
39 #include "tensorflow/core/framework/node_def_util.h"
40 #include "tensorflow/core/framework/op_def.pb.h"
41 #include "tensorflow/core/framework/versions.pb.h"
42 #include "tensorflow/core/graph/graph_constructor.h"
43 #include "tensorflow/core/graph/tensor_id.h"
44 #include "tensorflow/core/grappler/graph_topology_view.h"
45 #include "tensorflow/core/grappler/grappler_item.h"
46 #include "tensorflow/core/grappler/mutable_graph_view.h"
47 #include "tensorflow/core/grappler/op_types.h"
48 #include "tensorflow/core/grappler/utils.h"
49 #include "tensorflow/core/grappler/utils/functions.h"
50 #include "tensorflow/core/grappler/utils/topological_sort.h"
51 #include "tensorflow/core/grappler/utils/traversal.h"
52 #include "tensorflow/core/lib/gtl/map_util.h"
53
54 namespace tensorflow {
55 namespace grappler {
56 namespace {
57
58 // WARNING: Code in this file implicitly assumes that function input and output
59 // arguments are plain tensors (tensor lists are not supported). Function inputs
60 // and outputs are always expanded to a single placeholder or output tensor.
61 // With this assumption, the calling node's input/output ports always match
62 // function input/output arguments.
63 //
64 // This is guaranteed by the implementation of MakeGrapplerFunctionItem.
65
// Marks functions that were created as a result of function specialization.
constexpr char kGrapplerSpecializedFuncAttr[] = "_GrapplerSpecializedFunc";

// Name of the attribute that defines the function for indirect function calls.
constexpr char kFuncAttrName[] = "f";

// Name of the boolean function attribute that MarkedNoInline() checks.
constexpr char kNoInlineAttr[] = "_noinline";

// Name of the node that will have control edges from function input nodes, and
// is also used as a new destination for incoming control edges.
constexpr char kInputsReadyNodeName[] = "inputs_ready";

// Name of the node that will have control edges from function control output
// nodes, and is also used as a new source of outgoing control edges. This node
// will guarantee that all side-effects inside the function body will be
// executed after function inlining.
constexpr char kSideEffectsExecutedNodeName[] = "side_effects_executed";
83
AttrIsTrue(const FunctionDef & func,const string & attr)84 bool AttrIsTrue(const FunctionDef& func, const string& attr) {
85 return func.attr().count(attr) != 0 && func.attr().at(attr).b();
86 }
87
MarkedSpecialized(const FunctionDef & func)88 bool MarkedSpecialized(const FunctionDef& func) {
89 return AttrIsTrue(func, kGrapplerSpecializedFuncAttr);
90 }
91
MarkedNoInline(const FunctionDef & func)92 bool MarkedNoInline(const FunctionDef& func) {
93 return AttrIsTrue(func, kNoInlineAttr);
94 }
95
96 // There are two ways of calling a Tensorflow function:
97 //
98 // 1. Direct function call: node.op() is the name of the function.
99 //
100 // 2. Indirect function call: the function name is passed through a node
101 // attribute, and special Tensorflow kernels are responsible for calling the
102 // function through the FunctionLibraryRuntime. Example: PartitionedCallOp.
103
104 // Check if func_node.op() matches the name in FunctionDef signature.
IsDirectFunctionCall(const FunctionDef & func,const NodeDef & func_node)105 bool IsDirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
106 return func_node.op() == func.signature().name();
107 }
108
109 // Check if func_node has function attribute with a function name matching
110 // FunctionDef signature.
IsIndirectFunctionCall(const FunctionDef & func,const NodeDef & func_node)111 bool IsIndirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
112 auto* func_attr = AttrSlice(func_node).Find(kFuncAttrName);
113 return func_attr != nullptr && func_attr->has_func() &&
114 func_attr->func().name() == func.signature().name();
115 }
116
FunctionInstantiationAttributes(const FunctionDef & func,const NodeDef & func_node)117 AttrSlice FunctionInstantiationAttributes(const FunctionDef& func,
118 const NodeDef& func_node) {
119 if (IsDirectFunctionCall(func, func_node)) {
120 return AttrSlice(func_node);
121
122 } else if (IsIndirectFunctionCall(func, func_node)) {
123 auto* func_attr = AttrSlice(func_node).Find(kFuncAttrName);
124 return AttrSlice(&func_attr->func().attr());
125
126 } else {
127 LOG(WARNING) << "Can't resolve function instantiation attributes: "
128 << SummarizeNodeDef(func_node);
129 return AttrSlice();
130 }
131 }
132
133 // This is a fake device that should not be used for any op kernel execution,
134 // the only purpose of this device is to be passed as a part of DeviceSet to the
135 // Placer.
136 class FakeDevice : public Device {
137 public:
FakeDevice(Env * env,const string & device)138 FakeDevice(Env* env, const string& device) : Device(env, attr(device)) {}
FakeDevice(const string & device)139 explicit FakeDevice(const string& device) : FakeDevice(nullptr, device) {}
Sync()140 Status Sync() override { return Status::OK(); }
141
142 private:
attr(const string & device)143 static DeviceAttributes attr(const string& device) {
144 DeviceNameUtils::ParsedName parsed_name;
145 bool parsed = DeviceNameUtils::ParseFullName(device, &parsed_name);
146 DCHECK(parsed) << "Failed to parse full device name: " << device;
147
148 DeviceAttributes attr;
149 attr.set_name(device);
150 attr.set_device_type(parsed_name.type);
151 return attr;
152 }
153 };
154
155 // -------------------------------------------------------------------------- //
156 // Function specialization.
157 //
158 // FunctionDef is somewhat similar to function template in C++, given all the
159 // type parameters (and attribute values) it generates a statically defined
160 // graph from the type parametrized "graph template" (function body).
161 //
162 // Function specialization instantiates a parametrized FunctionDef into a
163 // statically defined graph, and then converts it back to the fully defined
164 // FunctionDef (it doesn't have any unknown type parameters or attribute
165 // values, known as placeholders).
166 //
167 // Given the fully specified graph we can apply all the Grappler optimizers to
168 // it (see details in MetaOptimizer). Also we can push known constant inputs
169 // into the function body, and remove unused outputs/inputs.
170
171 // Specialized function instantiation type parameters, body parameters, and
172 // const inputs.
173 struct FunctionSpecializationSignature {
174 // Currently we do not support functions with tensor lists as inputs or
175 // outputs, so caller node input/output ports always match function
176 // input/output arguments.
177 using InputPort = int;
178 using OutputPort = int;
179
180 string func_name;
181 bool is_in_fetch_set;
182 absl::flat_hash_set<OutputPort> active_outputs;
183 absl::flat_hash_map<string, DataType> type_parameters;
184 absl::flat_hash_map<string, AttrValue> body_parameters;
185 absl::flat_hash_map<InputPort, string> const_inputs;
186
operator ==tensorflow::grappler::__anondd39c0ba0111::FunctionSpecializationSignature187 bool operator==(const FunctionSpecializationSignature& other) const {
188 bool equals = func_name == other.func_name &&
189 is_in_fetch_set == other.is_in_fetch_set &&
190 active_outputs == other.active_outputs &&
191 type_parameters == other.type_parameters &&
192 const_inputs == other.const_inputs;
193
194 if (!equals) return false;
195
196 // Equality is not defined for AttrValue.
197 if (body_parameters.size() != other.body_parameters.size()) return false;
198
199 for (const auto& lhs : body_parameters) {
200 auto it = other.body_parameters.find(lhs.first);
201 if (it == other.body_parameters.end()) return false;
202 if (!FastAreAttrValuesEqual(lhs.second, (*it).second)) return false;
203 }
204
205 return true;
206 }
207
208 template <typename H>
AbslHashValue(H h,const FunctionSpecializationSignature & s)209 friend H AbslHashValue(H h, const FunctionSpecializationSignature& s) {
210 H base = H::combine(std::move(h), s.func_name, s.is_in_fetch_set);
211
212 // First pre-compute hashes for all values in collections with
213 // non-deterministic iteration order.
214 std::vector<uint64> hashes;
215 hashes.reserve(s.active_outputs.size() //
216 + s.type_parameters.size() * 2 //
217 + s.body_parameters.size() * 2 //
218 + s.const_inputs.size() * 2);
219
220 absl::c_transform(s.active_outputs, std::back_inserter(hashes),
221 hash<OutputPort>());
222
223 using TypeParam = std::pair<const string, DataType>;
224 absl::c_for_each(s.type_parameters, [&hashes](const TypeParam& type_param) {
225 AttrValue attr_value;
226 attr_value.set_type(type_param.second);
227 hashes.push_back(Hash64(type_param.first));
228 hashes.push_back(AttrValueHash(attr_value));
229 });
230
231 using BodyParam = std::pair<const string, AttrValue>;
232 absl::c_for_each(s.body_parameters, [&hashes](const BodyParam& body_param) {
233 hashes.push_back(Hash64(body_param.first));
234 hashes.push_back(FastAttrValueHash(body_param.second));
235 });
236
237 using ConstInput = std::pair<const InputPort, string>;
238 absl::c_for_each(s.const_inputs, [&hashes](const ConstInput& const_input) {
239 hashes.push_back(hash<InputPort>()(const_input.first));
240 hashes.push_back(Hash64(const_input.second));
241 });
242
243 // Combine all pre-computed hashes in a deterministic order.
244 absl::c_sort(hashes);
245 return H::combine_contiguous(std::move(base), hashes.data(), hashes.size());
246 }
247 };
248
// Description of a specialized function and of the changes that must be applied
// to a caller node so that it calls the specialization (see
// UpdateSpecializedFunctionNode).
struct FunctionSpecialization {
  // Name of the specialized function that the updated call site will call.
  string specialized_func_name;
  // True if the function caller node is in the GrapplerItem fetch set.
  bool is_in_fetch_set;
  // Names of the tensors that were pushed down into the function body.
  absl::flat_hash_set<string> const_inputs;
  // Control dependencies of pushed down const inputs have to be attached to
  // the function caller node.
  absl::flat_hash_set<string> control_deps;
  // Output tensors (ports) that are consumed by other nodes in the graph or
  // are in the GrapplerItem fetch set.
  absl::flat_hash_set<int> active_outputs;
  // Mapping from the original function output port to the output port of the
  // specialized function. If function specialization changes the number of
  // function outputs, all node consumers have to be updated.
  std::vector<std::pair<int, int>> output_mapping;
};
266
267 // Function optimizer context initialized once for each optimization pass, and
268 // it uses the latest available graph (for the first iteration it will be the
269 // GrapplerItem.graph, for next iterations it will be the output of previous
270 // function optimizer pass).
271 class FunctionOptimizerContext {
272 public:
FunctionOptimizerContext(const GrapplerItem & item,RewriterConfig::Toggle opt_level,const GraphDef & graph)273 explicit FunctionOptimizerContext(const GrapplerItem& item,
274 RewriterConfig::Toggle opt_level,
275 const GraphDef& graph)
276 : item_(&item),
277 opt_level_(opt_level),
278 function_library_(OpRegistry::Global(), graph.library()),
279 truly_const_nodes_(InferTrulyConstNodes(item, graph)),
280 graph_view_(&graph) {}
281
item() const282 const GrapplerItem& item() const { return *item_; }
283
graph_version() const284 const int graph_version() const { return item_->graph.versions().producer(); }
285
opt_level() const286 RewriterConfig::Toggle opt_level() const { return opt_level_; }
287
function_library() const288 const FunctionLibraryDefinition& function_library() const {
289 return function_library_;
290 }
291
mutable_function_library()292 FunctionLibraryDefinition* mutable_function_library() {
293 return &function_library_;
294 }
295
mutable_function_library_runtime()296 FunctionLibraryRuntime* mutable_function_library_runtime() {
297 InitializeFunctionLibraryRuntime();
298 return flr_;
299 }
300
301 const absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>&
tensor_mapping() const302 tensor_mapping() const {
303 return tensor_mapping_;
304 }
305
control_overrides() const306 const absl::flat_hash_map<string, std::vector<string>>& control_overrides()
307 const {
308 return control_overrides_;
309 }
310
graph_view() const311 const GraphView& graph_view() const { return graph_view_; }
312
devices() const313 const DeviceSet* devices() const {
314 // Create fake devices lazily only if we need a DeviceSet.
315 if (available_devices_.empty() && !item_->devices().empty()) {
316 for (const string& name : item_->devices()) {
317 auto device = absl::make_unique<FakeDevice>(name);
318 available_device_set_.AddDevice(device.get());
319 available_devices_.push_back(std::move(device));
320 }
321 }
322 return &available_device_set_;
323 }
324
IsFetchNode(const string & node_name) const325 bool IsFetchNode(const string& node_name) const {
326 return absl::c_any_of(item_->fetch, [&](const string& fetch) {
327 return ParseTensorName(fetch).node() == node_name;
328 });
329 }
330
IsKeepOp(const string & node_name) const331 bool IsKeepOp(const string& node_name) const {
332 return absl::c_any_of(item_->keep_ops, [&](const string& keep_node) {
333 return keep_node == node_name;
334 });
335 }
336
IsTrulyConst(const string & name) const337 bool IsTrulyConst(const string& name) const {
338 return TrulyConstNode(name) != nullptr;
339 }
340
TrulyConstNode(const string & name) const341 const NodeDef* TrulyConstNode(const string& name) const {
342 return gtl::FindWithDefault(truly_const_nodes_, name, nullptr);
343 }
344
FindFunctionSpecialization(const FunctionSpecializationSignature & sig) const345 const FunctionSpecialization* FindFunctionSpecialization(
346 const FunctionSpecializationSignature& sig) const {
347 return gtl::FindOrNull(specialized_functions_, sig);
348 }
349
AddSpecializedFunction(const FunctionSpecializationSignature & sig,const FunctionSpecialization & specialized_func)350 void AddSpecializedFunction(const FunctionSpecializationSignature& sig,
351 const FunctionSpecialization& specialized_func) {
352 specialized_functions_.emplace(sig, specialized_func);
353 }
354
AddTensorMapping(const SafeTensorId & from,const SafeTensorId & to)355 void AddTensorMapping(const SafeTensorId& from, const SafeTensorId& to) {
356 DCHECK(from.index() != Graph::kControlSlot)
357 << "Tensor mapping must be from regular tensor";
358 DCHECK(to.index() != Graph::kControlSlot)
359 << "Tensor mapping must be to regular tensor";
360
361 auto inserted = tensor_mapping_.insert({from, to});
362 DCHECK(inserted.second)
363 << "Failed to insert duplicated tensor mapping: "
364 << "from=" << from.ToString() << " to=" << to.ToString();
365 }
366
AddTensorMapping(const string & func_node,const FunctionSpecialization & specialized_func)367 void AddTensorMapping(const string& func_node,
368 const FunctionSpecialization& specialized_func) {
369 for (const auto& pair : specialized_func.output_mapping) {
370 int from_idx = pair.first;
371 int to_idx = pair.second;
372 if (from_idx != to_idx) {
373 SafeTensorId from_tensor(func_node, from_idx);
374 SafeTensorId to_tensor(func_node, to_idx);
375 AddTensorMapping(from_tensor, to_tensor);
376 }
377 }
378 }
379
AddControlOverrides(const NodeDef & func_node,const std::vector<string> & control_overrides)380 void AddControlOverrides(const NodeDef& func_node,
381 const std::vector<string>& control_overrides) {
382 VLOG(4) << "Add control overrides: from=" << func_node.name() << " to: ["
383 << absl::StrJoin(control_overrides, ", ") << "]";
384
385 control_overrides_[func_node.name()].reserve(control_overrides.size());
386 for (const string& control_override : control_overrides) {
387 control_overrides_[func_node.name()].push_back(control_override);
388 }
389 }
390
391 private:
InferTrulyConstNodes(const GrapplerItem & item,const GraphDef & graph)392 static absl::flat_hash_map<string, const NodeDef*> InferTrulyConstNodes(
393 const GrapplerItem& item, const GraphDef& graph) {
394 absl::flat_hash_set<absl::string_view> feed_nodes;
395 for (const auto& feed : item.feed) {
396 feed_nodes.insert(feed.first);
397 }
398
399 absl::flat_hash_map<string, const NodeDef*> const_nodes;
400 for (const NodeDef& node : graph.node()) {
401 if (IsConstant(node) && !feed_nodes.contains(node.name())) {
402 const_nodes[node.name()] = &node;
403 }
404 }
405
406 return const_nodes;
407 }
408
InitializeFunctionLibraryRuntime()409 void InitializeFunctionLibraryRuntime() {
410 if (!flr_) {
411 Env* env = Env::Default();
412 std::vector<std::unique_ptr<Device>> devices;
413 devices.push_back(absl::make_unique<FakeDevice>(env, "/device:CPU:0"));
414 device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
415 OptimizerOptions optimizer_opts;
416 optimizer_opts.set_do_function_inlining(true);
417 process_flr_.reset(new ProcessFunctionLibraryRuntime(
418 device_mgr_.get(), env, item_->graph.versions().producer(),
419 &function_library_, optimizer_opts));
420 flr_ = process_flr_->GetFLR(device_mgr_->ListDevices()[0]->name());
421 }
422 }
423
424 const GrapplerItem* item_; // must outlive this object
425 RewriterConfig::Toggle opt_level_;
426
427 // Function library constructed from current graph.
428 FunctionLibraryDefinition function_library_;
429
430 // These fields initialized lazily only if needed.
431 std::unique_ptr<DeviceMgr> device_mgr_;
432 std::unique_ptr<ProcessFunctionLibraryRuntime> process_flr_;
433 FunctionLibraryRuntime* flr_ = nullptr;
434
435 // List of available `FakedDevices` (lazily initialized, see devices()).
436 mutable std::vector<std::unique_ptr<Device>> available_devices_;
437
438 // DeviceSet of fake devices (`FakeDevice`) constructed from
439 // item_.devices() (lazily initialized).
440 mutable DeviceSet available_device_set_;
441
442 // Nodes that are Const and not in feed.
443 absl::flat_hash_map<string, const NodeDef*> truly_const_nodes_;
444 // Specialized functions.
445 absl::flat_hash_map<FunctionSpecializationSignature,
446 const FunctionSpecialization>
447 specialized_functions_;
448
449 // After function inlining and specialization, the optimized graph might be in
450 // invalid state, nodes can read from non-existing function call nodes that
451 // were inlined, or they can read from output index that is no longer valid
452 // after unused outputs pruning.
453 //
454 // Tensor mapping that has to be applied to the graph after all functions
455 // optimizations (invalidated tensor id -> optimized graph tensor id).
456 absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>
457 tensor_mapping_;
458
459 // When we inline a function into the optimized graph, we no longer have the
460 // function call node to anchor control dependencies. Instead we must expand
461 // each function call control output edge into multiple control dependencies
462 // to all side-effectful ops inside the function body.
463 //
464 // Invalidated function call node name -> Inlined side-effectful nodes
465 absl::flat_hash_map<string, std::vector<string>> control_overrides_;
466
467 // Use graph view to find active outputs of the function caller nodes.
468 GraphView graph_view_;
469
470 TF_DISALLOW_COPY_AND_ASSIGN(FunctionOptimizerContext);
471 };
472
473 // Returns a pointer to the called function definition iff the given node is
474 // indeed a function call. Otherwise returns nullptr.
FindFunctionCall(const FunctionOptimizerContext & ctx,const NodeDef & node)475 const FunctionDef* FindFunctionCall(const FunctionOptimizerContext& ctx,
476 const NodeDef& node) {
477 // Check if a node does indirect function call via PartitionedCallOp.
478 if (IsPartitionedCall(node) || IsStatefulPartitionedCall(node)) {
479 const AttrValue* func_attr = AttrSlice(node).Find("f");
480 return (func_attr != nullptr && func_attr->has_func())
481 ? ctx.function_library().Find(func_attr->func().name())
482 : nullptr;
483 }
484
485 // Check if the function op itself is a function name.
486 return ctx.function_library().Find(node.op());
487 }
488
GetActiveOutputs(const NodeDef & node,const FunctionOptimizerContext & ctx,int size_hint=0)489 absl::flat_hash_set<int> GetActiveOutputs(const NodeDef& node,
490 const FunctionOptimizerContext& ctx,
491 int size_hint = 0) {
492 absl::flat_hash_set<int> active_outputs;
493 active_outputs.reserve(static_cast<size_t>(size_hint));
494
495 // 1. Output can be consumed by the other graph node.
496 const auto node_fanout_edges =
497 ctx.graph_view().GetFanoutEdges(node, /*include_controlled_edges=*/false);
498 for (const GraphView::Edge& edge : node_fanout_edges) {
499 active_outputs.insert(edge.src.port_id);
500 }
501
502 // 2. Or it can be in a fetch set.
503 for (const string& fetch : ctx.item().fetch) {
504 TensorId fetch_tensor = ParseTensorName(fetch);
505 if (fetch_tensor.node() == node.name()) {
506 active_outputs.insert(fetch_tensor.index());
507 }
508 }
509
510 return active_outputs;
511 }
512
HasTrulyConstInputs(const NodeDef & node,const FunctionOptimizerContext & ctx)513 bool HasTrulyConstInputs(const NodeDef& node,
514 const FunctionOptimizerContext& ctx) {
515 const auto is_truly_const = [&ctx](const string& input) {
516 return ctx.IsTrulyConst(NodeName(input));
517 };
518 return absl::c_any_of(node.input(), is_truly_const);
519 }
520
HasUnusedOutputs(const NodeDef & func_node,const FunctionDef & func,const FunctionOptimizerContext & ctx)521 bool HasUnusedOutputs(const NodeDef& func_node, const FunctionDef& func,
522 const FunctionOptimizerContext& ctx) {
523 // Functions with tensor list outputs are not supported right now, so the
524 // number of output args is the same as number of possible function caller
525 // node outputs.
526 int num_outputs = func.signature().output_arg_size();
527 const absl::flat_hash_set<int> active_outputs =
528 GetActiveOutputs(func_node, ctx, /*size_hind*/ num_outputs);
529
530 return active_outputs.size() != num_outputs;
531 }
532
533 // Return pruned FunctionDefLibrary with functions that are reachable from
534 // the optimized graph.
PruneFunctionLibrary(const FunctionLibraryDefinition & flib,const GraphDef & optimized_graph)535 FunctionDefLibrary PruneFunctionLibrary(const FunctionLibraryDefinition& flib,
536 const GraphDef& optimized_graph) {
537 FunctionLibraryDefinition pruned_flib =
538 flib.ReachableDefinitions(optimized_graph);
539
540 int pruned_functions = static_cast<int>(pruned_flib.num_functions()) -
541 static_cast<int>(flib.num_functions());
542
543 VLOG(3) << "Pruned function library: " << pruned_flib.num_functions()
544 << " functions (" << pruned_functions << ")";
545
546 return pruned_flib.ToProto();
547 }
548
549 // Push all constant inputs of an instantiating node into the function body.
PushDownConstInputs(const NodeDef & func_node,const FunctionOptimizerContext & ctx,GrapplerFunctionItem * item,absl::flat_hash_set<string> * const_inputs,absl::flat_hash_set<string> * control_deps)550 Status PushDownConstInputs(const NodeDef& func_node,
551 const FunctionOptimizerContext& ctx,
552 GrapplerFunctionItem* item,
553 absl::flat_hash_set<string>* const_inputs,
554 absl::flat_hash_set<string>* control_deps) {
555 // Record node control dependencies in the control_deps set.
556 const auto record_control_deps = [&](const NodeDef* const_input) {
557 for (int i = const_input->input_size() - 1; i >= 0; --i) {
558 const string& input = const_input->input(i);
559 if (IsControlInput(input))
560 control_deps->insert(input);
561 else
562 break;
563 }
564 };
565
566 for (int i = func_node.input_size() - 1; i >= 0; --i) {
567 const string& input = func_node.input(i);
568 if (IsControlInput(input)) continue;
569
570 const string node_name = NodeName(input);
571 if (ctx.IsTrulyConst(node_name)) {
572 VLOG(3) << "Push const into function body: input=" << input;
573 const auto* const_input = CHECK_NOTNULL(ctx.TrulyConstNode(node_name));
574 const_inputs->insert(input);
575 record_control_deps(const_input);
576 TF_RETURN_IF_ERROR(ReplaceInputWithConst(*const_input, i, item));
577 }
578 }
579
580 return Status::OK();
581 }
582
583 // Remove inputs that were pushed into the function body, and attach their
584 // control dependencies to the function caller node.
RemovePushedDownConstInputs(const FunctionSpecialization & specialization,NodeDef * specialized_func_node)585 void RemovePushedDownConstInputs(const FunctionSpecialization& specialization,
586 NodeDef* specialized_func_node) {
587 // Nothing to do if it was no const inputs to the function node.
588 if (specialization.const_inputs.empty()) return;
589
590 // Keep only non-const inputs.
591 std::vector<string> keep_inputs;
592 const auto& inputs = specialized_func_node->input();
593 std::copy_if(inputs.begin(), inputs.end(), std::back_inserter(keep_inputs),
594 [&](const string& input) {
595 return specialization.const_inputs.find(input) ==
596 specialization.const_inputs.end();
597 });
598
599 specialized_func_node->clear_input();
600 for (const auto& keep : keep_inputs) specialized_func_node->add_input(keep);
601
602 // Attach control dependencies of pushed down const input to the caller node.
603 if (!specialization.control_deps.empty()) {
604 absl::flat_hash_set<string> existing_control_deps;
605
606 for (const string& input : keep_inputs) {
607 existing_control_deps.insert(AsControlDependency(NodeName(input)));
608 }
609
610 for (const string& ctrl : specialization.control_deps) {
611 if (existing_control_deps.find(ctrl) == existing_control_deps.end()) {
612 VLOG(3) << "Forward control dependency: input=" << ctrl;
613 specialized_func_node->add_input(ctrl);
614 }
615 }
616 }
617 }
618
619 // Remove Tin type parameters for pushed down const inputs.
RemovePushedDownConstInputTypes(const FunctionSpecialization & specialization,const NodeDef & func_node,NodeDef * specialized_func_node)620 void RemovePushedDownConstInputTypes(
621 const FunctionSpecialization& specialization, const NodeDef& func_node,
622 NodeDef* specialized_func_node) {
623 // Nothing to do if it was no const inputs to the function node.
624 if (specialization.const_inputs.empty()) return;
625
626 // Make sure that original function caller has Tin attribute.
627 const AttrValue* tin = AttrSlice(func_node).Find("Tin");
628 if (tin == nullptr || !tin->has_list()) return;
629
630 // Clear input types for the specialized node.
631 auto* attr = specialized_func_node->mutable_attr();
632 (*attr)["Tin"].mutable_list()->clear_type();
633
634 // Keep types of non-const inputs.
635 for (int i = 0; i < func_node.input_size(); ++i) {
636 const string& input = func_node.input(i);
637 if (IsControlInput(input)) break;
638
639 if (specialization.const_inputs.find(input) ==
640 specialization.const_inputs.end()) {
641 DataType dt = tin->list().type(i);
642 (*attr)["Tin"].mutable_list()->add_type(dt);
643 }
644 }
645 }
646
647 // Remove Tout type parameters for pruned function outputs.
RemoveUnusedOutputsTypes(const FunctionSpecialization & specialization,const NodeDef & func_node,NodeDef * specialized_func_node)648 void RemoveUnusedOutputsTypes(const FunctionSpecialization& specialization,
649 const NodeDef& func_node,
650 NodeDef* specialized_func_node) {
651 // Make sure that original function caller has Tout attribute.
652 const AttrValue* tout = AttrSlice(func_node).Find("Tout");
653 if (tout == nullptr || !tout->has_list()) return;
654
655 // Nothing to do if all outputs are active.
656 if (specialization.active_outputs.size() == tout->list().type_size()) return;
657
658 // Clear input types for the specialized node.
659 auto* attr = specialized_func_node->mutable_attr();
660 (*attr)["Tout"].mutable_list()->clear_type();
661
662 // Keep output types of active outputs only.
663 for (int i = 0; i < tout->list().type_size(); ++i) {
664 if (specialization.active_outputs.find(i) !=
665 specialization.active_outputs.end()) {
666 DataType dt = tout->list().type(i);
667 (*attr)["Tout"].mutable_list()->add_type(dt);
668 }
669 }
670 }
671
UpdateSpecializedFunctionCallSite(const FunctionDef & func,const NodeDef & func_node,const string & specialized_func_name,NodeDef * specialized_func_node)672 Status UpdateSpecializedFunctionCallSite(const FunctionDef& func,
673 const NodeDef& func_node,
674 const string& specialized_func_name,
675 NodeDef* specialized_func_node) {
676 if (IsDirectFunctionCall(func, func_node)) {
677 specialized_func_node->set_op(specialized_func_name);
678
679 } else if (IsIndirectFunctionCall(func, func_node)) {
680 auto* attr = specialized_func_node->mutable_attr();
681 (*attr)[kFuncAttrName].mutable_func()->set_name(specialized_func_name);
682
683 } else {
684 return errors::InvalidArgument("Unknown function call site");
685 }
686
687 return Status::OK();
688 }
689
690 // Update a graph node created from the original function caller node, to the
691 // function specialization. Function specialization might change the number of
692 // inputs and outputs, so we have to make sure that graph node is updated
693 // accordingly.
UpdateSpecializedFunctionNode(const FunctionDef & func,const NodeDef & func_node,const FunctionSpecialization & specialization,NodeDef * specialized_func_node)694 Status UpdateSpecializedFunctionNode(
695 const FunctionDef& func, const NodeDef& func_node,
696 const FunctionSpecialization& specialization,
697 NodeDef* specialized_func_node) {
698 // Function called indirectly via custom kernel (e.g. PartitionedCallOp).
699 bool is_indirect_call = IsIndirectFunctionCall(func, func_node);
700
701 // 1. Call the specialized function instead of original one.
702 TF_RETURN_IF_ERROR(UpdateSpecializedFunctionCallSite(
703 func, func_node, specialization.specialized_func_name,
704 specialized_func_node));
705
706 // 2. Remove inputs corresponding to the pushed down consts.
707 RemovePushedDownConstInputs(specialization, specialized_func_node);
708
709 // NOTE: PartitionedCallOp has `Tin` and `Tout` attributes for input/output
710 // types, that must be in sync with updated function signature.
711
712 // 3. Update input types for the indirect function calls.
713 if (is_indirect_call) {
714 RemovePushedDownConstInputTypes(specialization, func_node,
715 specialized_func_node);
716 }
717
718 // 4. Update output types for the indirect function call. It's unsafe to
719 // change the number of outputs for the fetch nodes, so we just skip them.
720 if (is_indirect_call && !specialization.is_in_fetch_set) {
721 RemoveUnusedOutputsTypes(specialization, func_node, specialized_func_node);
722 }
723
724 // 5. Remove custom gradient annotation.
725 specialized_func_node->mutable_attr()->erase("_gradient_op_type");
726
727 return Status::OK();
728 }
729
InitializeFunctionSpecializationSignature(const NodeDef & func_node,const FunctionDef & func,const AttrSlice & func_instantiation_attr,const FunctionOptimizerContext & ctx,FunctionSpecializationSignature * sig)730 Status InitializeFunctionSpecializationSignature(
731 const NodeDef& func_node, const FunctionDef& func,
732 const AttrSlice& func_instantiation_attr,
733 const FunctionOptimizerContext& ctx, FunctionSpecializationSignature* sig) {
734 DCHECK(sig->const_inputs.empty());
735 DCHECK(sig->active_outputs.empty());
736
737 sig->func_name = func.signature().name();
738 sig->is_in_fetch_set = ctx.IsFetchNode(func_node.name());
739 sig->active_outputs = GetActiveOutputs(func_node, ctx);
740
741 TF_RETURN_IF_ERROR(InstantiationTypeParameters(func, func_instantiation_attr,
742 &sig->type_parameters));
743 TF_RETURN_IF_ERROR(InstantiationBodyParameters(func, func_instantiation_attr,
744 &sig->body_parameters));
745
746 for (int i = 0; i < func_node.input_size(); ++i) {
747 const string& input = func_node.input(i);
748 if (IsControlInput(input)) break;
749 if (ctx.IsTrulyConst(input)) {
750 sig->const_inputs.emplace(i, input);
751 }
752 }
753
754 return Status::OK();
755 }
756
757 // Create a name for the function specialization. The name of the function, name
758 // of the node instantiating it, and a Grappler item id should generate unique
759 // function name. Meta optimizer might create multiple Grappler items for the
760 // same graph when optimizing functions, but it's guaranteed that they all will
761 // have unique ids.
SpecializedFunctionName(const FunctionOptimizerContext & ctx,const FunctionDef & func,const NodeDef & func_node)762 string SpecializedFunctionName(const FunctionOptimizerContext& ctx,
763 const FunctionDef& func,
764 const NodeDef& func_node) {
765 return absl::Substitute(
766 "$0_specialized_for_$1_at_$2", func.signature().name(),
767 absl::StrReplaceAll(func_node.name(), {{"/", "_"}}), ctx.item().id);
768 }
769
SpecializeFunction(const NodeDef & func_node,const FunctionDef & func,FunctionOptimizerContext * ctx,GraphDef * optimized_graph)770 Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
771 FunctionOptimizerContext* ctx,
772 GraphDef* optimized_graph) {
773 VLOG(2) << "Specialize function call: " << SummarizeNodeDef(func_node);
774
775 const AttrSlice func_instantiation_attr =
776 FunctionInstantiationAttributes(func, func_node);
777
778 FunctionSpecializationSignature signature;
779 TF_RETURN_IF_ERROR(InitializeFunctionSpecializationSignature(
780 func_node, func, func_instantiation_attr, *ctx, &signature));
781
782 // Check if function was already specialized for identical context.
783 const FunctionSpecialization* already_specialized =
784 ctx->FindFunctionSpecialization(signature);
785
786 if (already_specialized) {
787 VLOG(2) << "Function was already specialized in identical context: "
788 "specialized_name="
789 << already_specialized->specialized_func_name;
790
791 // Add a function call node for the specialized function.
792 NodeDef* specialized_func_node = optimized_graph->add_node();
793 *specialized_func_node = func_node;
794
795 TF_RETURN_IF_ERROR(UpdateSpecializedFunctionNode(
796 func, func_node, *already_specialized, specialized_func_node));
797
798 ctx->AddTensorMapping(specialized_func_node->name(), *already_specialized);
799
800 return Status::OK();
801 }
802
803 // Add a new specialized function definition to the library.
804 const auto& flib = ctx->function_library();
805
806 // Make a GrapplerFunctionItem and convert it back to FunctionDef after
807 // pushing all constant inputs into the function body.
808 GrapplerFunctionItem item;
809 TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
810 func, func_instantiation_attr, flib, ctx->graph_version(), &item));
811
812 // Push const inputs into the function body, and keep track of their control
813 // dependencies.
814 absl::flat_hash_set<string> const_inputs;
815 absl::flat_hash_set<string> control_deps;
816 TF_RETURN_IF_ERROR(PushDownConstInputs(func_node, *ctx, &item, &const_inputs,
817 &control_deps));
818
819 // Remove function outputs that do not have any consumers. We can't safely
820 // update outputs for the fetch nodes, so we just skip them.
821 std::vector<std::pair<int, int>> output_mapping;
822 if (!signature.is_in_fetch_set) {
823 int num_func_outputs = 0;
824 for (const auto& out_arg : item.outputs()) {
825 num_func_outputs += out_arg.output_nodes.size();
826 }
827
828 absl::flat_hash_set<int> remove;
829 for (int i = 0; i < num_func_outputs; ++i) {
830 if (!signature.active_outputs.count(i)) remove.insert(i);
831 }
832
833 TF_RETURN_IF_ERROR(RemoveFunctionOutputs(remove, &item, &output_mapping));
834 }
835
836 // TODO(ezhulenev): Push down known input shapes.
837 FunctionDef specialized_func;
838 TF_RETURN_IF_ERROR(MakeFunctionDef(item, flib, &specialized_func));
839
840 // Find a name for specialized function.
841 const string specialized_func_name =
842 SpecializedFunctionName(*ctx, func, func_node);
843 if (flib.Contains(specialized_func_name)) {
844 // NOTE(ezhulenev): This should never happen. If it happens, it's a sign of
845 // a serious internal error, that must be investigated.
846 return errors::Internal("Created duplicate function specialization");
847 }
848
849 specialized_func.mutable_signature()->set_name(specialized_func_name);
850 auto* specialized_attr = specialized_func.mutable_attr();
851 (*specialized_attr)[kGrapplerSpecializedFuncAttr].set_b(true);
852
853 // Add specialized function to the library.
854 TF_RETURN_IF_ERROR(
855 ctx->mutable_function_library()->AddFunctionDef(specialized_func));
856
857 // Add a function call node for the specialized function.
858 NodeDef* specialized_func_node = optimized_graph->add_node();
859 *specialized_func_node = func_node;
860
861 FunctionSpecialization func_specialization = {
862 specialized_func_name, signature.is_in_fetch_set, const_inputs,
863 control_deps, signature.active_outputs, output_mapping};
864
865 TF_RETURN_IF_ERROR(UpdateSpecializedFunctionNode(
866 func, func_node, func_specialization, specialized_func_node));
867
868 ctx->AddSpecializedFunction(signature, func_specialization);
869 ctx->AddTensorMapping(specialized_func_node->name(), func_specialization);
870
871 return Status::OK();
872 }
873
874 // -------------------------------------------------------------------------- //
875 // Inline direct functions calls.
876 //
877 // When we inline direct function calls, we instantiate the function body from
878 // its FunctionDef and caller node attributes, and embed the instantiated graph
879 // into the "main graph". When we do that, we must preserve the function call
880 // semantics:
881 //
882 // 1) All input nodes must be executed before any of function body nodes will
883 // start executing.
884 // 2) All function body nodes must be executed before any of the nodes, reading
885 // outputs of the function will start executing.
886 // 3) All nodes with side effects inside a function must be executed, this is
887 // different from the nodes with side effects in the main graph, that can be
888 // pruned if they are not in transitive dependency set of any of the fetch
889 // nodes.
890 // 4) All nodes of the function body must be executed on the device specified
891 // by the function caller node.
892 //
893 // To guarantee that function call semantics are preserved after inlining, we
894 // insert an IdentityN node before the inlined function body, and hook all
895 // inputs into that, and we insert another IdentityN node to hook all function
896 // outputs to it.
897
898 // Returns `Status::OK()` iff `node` is a direct function call of `func`, and we
899 // know how to inline it into the main graph, otherwise returns and error
900 // indicating why the function call is not inlinable.
IsInlinableDirectFunctionCall(const FunctionOptimizerContext & ctx,const FunctionDef & func,const NodeDef & func_node)901 Status IsInlinableDirectFunctionCall(const FunctionOptimizerContext& ctx,
902 const FunctionDef& func,
903 const NodeDef& func_node) {
904 // Indirect function calls (PartitionedCallOp) have automatic control
905 // dependencies and inlined separately from direct function calls.
906 if (!IsDirectFunctionCall(func, func_node)) {
907 return errors::InvalidArgument("Unsupported function call type: ",
908 SummarizeNodeDef(func_node));
909 }
910
911 // For direct function calls we insert IdentityN nodes before/after inlined
912 // function body to preserve function call semantics (all inputs evaluated
913 // before function evaluation starts, and all function body nodes finished
914 // before output consumed by other nodes).
915 if (func.signature().input_arg_size() == 0) {
916 return errors::FailedPrecondition(
917 "Can't inline direct function call with empty inputs: ",
918 SummarizeNodeDef(func_node));
919 }
920
921 // TODO(ezhulenev): Relax constraint on output args?
922 if (func.signature().output_arg_size() == 0) {
923 return errors::FailedPrecondition(
924 "Can't inline direct function call with empty outputs: ",
925 SummarizeNodeDef(func_node));
926 }
927
928 // Function must execute all the nodes in a function body that might have side
929 // effects. After inlining these nodes into the main graph, we can no longer
930 // guarantee that. For now we disable inlining functions with side effects.
931 //
932 // Attaching control dependency to the output IdentityN node is not safe,
933 // because it might be split or pruned in a later optimization pass.
934 //
935 // Indirect function calls (via PartitionedCallOp) have automatic dependency
936 // tracking, and allow us to safely inline functions with side effects.
937 bool has_side_effects =
938 absl::c_any_of(func.node_def(), [&ctx](const NodeDef& node) {
939 return !IsFreeOfSideEffect(node, &ctx.function_library());
940 });
941 if (has_side_effects) {
942 return errors::FailedPrecondition(
943 "Can't inline function with side-effects in the function body: ",
944 SummarizeNodeDef(func_node));
945 }
946
947 // We ignore `_noinline` marker in aggressive mode.
948 bool aggressive = ctx.opt_level() == RewriterConfig::AGGRESSIVE;
949 if (MarkedNoInline(func) && !aggressive) {
950 return errors::FailedPrecondition(
951 "Can't inline function marked with '_noinline': ",
952 SummarizeNodeDef(func_node));
953 }
954
955 // Function specialization and inlining must be mutually exclusive.
956 if (MarkedSpecialized(func)) {
957 return errors::FailedPrecondition(
958 "Can't inline function created in Grappler function specialization: ",
959 SummarizeNodeDef(func_node));
960 }
961
962 return Status::OK();
963 }
964
965 // Create an IdentityN node to hook the function inputs to: this ensures that
966 // they're all evaluated before the evaluation of the function body starts.
InlinedFunctionInputsNode(const NodeDef & func_node,const GrapplerFunctionItem & item)967 NodeDef InlinedFunctionInputsNode(const NodeDef& func_node,
968 const GrapplerFunctionItem& item) {
969 NodeDef inputs;
970 inputs.set_name(strings::StrCat(func_node.name(), "/", "inlined_inputs"));
971 inputs.set_op("IdentityN");
972 inputs.set_device(func_node.device());
973 *inputs.mutable_input() = func_node.input();
974 AttrValue::ListValue* type_list =
975 (*inputs.mutable_attr())["T"].mutable_list();
976
977 for (const InputArgExpansion& input_arg : item.inputs()) {
978 for (int i = 0; i < input_arg.placeholders.size(); ++i) {
979 type_list->add_type(input_arg.data_type);
980 }
981 }
982
983 return inputs;
984 }
985
986 // Create an IdentityN node to hook the function outputs to: this ensures that
987 // the function body is fully evaluated before its fanout gets scheduled.
InlinedFunctionOutputsNode(const NodeDef & func_node,const GrapplerFunctionItem & item,const absl::flat_hash_map<absl::string_view,absl::string_view> output_tensors)988 NodeDef InlinedFunctionOutputsNode(
989 const NodeDef& func_node, const GrapplerFunctionItem& item,
990 const absl::flat_hash_map<absl::string_view, absl::string_view>
991 output_tensors) {
992 NodeDef outputs;
993 outputs.set_name(func_node.name());
994 outputs.set_op("IdentityN");
995 outputs.set_device(func_node.device());
996 AttrValue::ListValue* type_list =
997 (*outputs.mutable_attr())["T"].mutable_list();
998
999 for (const OutputArgExpansion& output_arg : item.outputs()) {
1000 for (const string& output_node : output_arg.output_nodes) {
1001 const absl::string_view output_tensor = output_tensors.at(output_node);
1002 type_list->add_type(output_arg.data_type);
1003 outputs.add_input(strings::StrCat(func_node.name(), "/", output_tensor));
1004 }
1005 }
1006
1007 return outputs;
1008 }
1009
InlineDirectFunctionCall(const NodeDef & func_node,const FunctionDef & func,const FunctionOptimizerContext & ctx,GraphDef * optimized_graph)1010 Status InlineDirectFunctionCall(const NodeDef& func_node,
1011 const FunctionDef& func,
1012 const FunctionOptimizerContext& ctx,
1013 GraphDef* optimized_graph) {
1014 VLOG(2) << "Inline direct function call: " << SummarizeNodeDef(func_node);
1015 TF_RETURN_IF_ERROR(IsInlinableDirectFunctionCall(ctx, func, func_node));
1016
1017 const AttrSlice func_instantiation_attr =
1018 FunctionInstantiationAttributes(func, func_node);
1019
1020 GrapplerFunctionItem item;
1021 Status item_status = MakeGrapplerFunctionItem(func, func_instantiation_attr,
1022 ctx.function_library(),
1023 ctx.graph_version(), &item);
1024
1025 if (!item_status.ok()) {
1026 return errors::InvalidArgument("Failed to inline function ", func_node.op(),
1027 " instantiated by ", func_node.name(),
1028 ". Error: ", item_status.error_message());
1029 }
1030
1031 // Mapping from input placeholder name to function input position.
1032 absl::flat_hash_map<absl::string_view, int> input_placeholders_idx;
1033 for (const InputArgExpansion& input_arg : item.inputs()) {
1034 for (const string& placeholder : input_arg.placeholders) {
1035 const int idx = input_placeholders_idx.size();
1036 input_placeholders_idx[placeholder] = idx;
1037 }
1038 }
1039
1040 // Bypass identity nodes added to the graph in place of function outputs.
1041 absl::flat_hash_set<absl::string_view> output_nodes;
1042 for (const OutputArgExpansion& output_arg : item.outputs()) {
1043 for (const string& output_node : output_arg.output_nodes) {
1044 output_nodes.insert(output_node);
1045 }
1046 }
1047
1048 // For each function output value we added an identity node that reads the
1049 // tensor from one of the function body nodes. When we inline function into
1050 // the main graph we want to bypass these nodes, so we keep a mapping from
1051 // 'output node name' -> 'output tensor name'.
1052 absl::flat_hash_map<absl::string_view, absl::string_view> output_tensors;
1053
1054 // Hook inlined function inputs to IdentityN node.
1055 NodeDef* func_inputs = optimized_graph->add_node();
1056 *func_inputs = InlinedFunctionInputsNode(func_node, item);
1057
1058 for (NodeDef& func_body_node : *item.mutable_function_body().mutable_node()) {
1059 const string& node_name = func_body_node.name();
1060
1061 // Skip output identity node, and update a mapping to the output tensor.
1062 if (IsIdentity(func_body_node) && output_nodes.count(node_name)) {
1063 output_tensors.emplace(node_name, func_body_node.input(0));
1064 continue;
1065 }
1066
1067 // Turn placeholders added in place of input arguments into identity nodes.
1068 const auto input_placeholder_idx = input_placeholders_idx.find(node_name);
1069 if (input_placeholder_idx != input_placeholders_idx.end()) {
1070 CHECK_EQ(0, func_body_node.input_size());
1071 func_body_node.set_op("Identity");
1072 (*func_body_node.mutable_attr())["T"] = func_body_node.attr().at("dtype");
1073 func_body_node.mutable_attr()->erase("dtype");
1074 func_body_node.mutable_attr()->erase("shape");
1075 func_body_node.add_input(strings::StrCat(func_inputs->name(), ":",
1076 input_placeholder_idx->second));
1077 } else {
1078 // Update the input names if any.
1079 for (string& input : *func_body_node.mutable_input()) {
1080 input = AddPrefixToNodeName(input, /*prefix=*/func_node.name());
1081 }
1082 // If the node has no input, make hook it up to the func_inputs node to
1083 // ensure it runs in the same frame as the other nodes of the function
1084 // body.
1085 if (func_body_node.input_size() == 0) {
1086 *func_body_node.add_input() = AsControlDependency(func_inputs->name());
1087 }
1088 }
1089
1090 // Add the function node name as a prefix 1) to node name to avoid
1091 // collisions; 2) to frame name to avoid multiple LoopCond nodes in one
1092 // frame after inlining.
1093 const string prefix = strings::StrCat(func_node.name(), "/");
1094 TF_RETURN_IF_ERROR(
1095 AddPrefixAndSuffixToNode(prefix, "" /* suffix */, &func_body_node));
1096
1097 // Make sure the node is placed.
1098 func_body_node.set_device(func_node.device());
1099
1100 // Move the node to the main graph.
1101 optimized_graph->add_node()->Swap(&func_body_node);
1102 }
1103
1104 DCHECK(output_tensors.size() == item.output_size())
1105 << "Each function output must be mapped to an output tensor";
1106
1107 // Hook inlined function outputs to IdentityN node.
1108 NodeDef* func_outputs = optimized_graph->add_node();
1109 *func_outputs = InlinedFunctionOutputsNode(func_node, item, output_tensors);
1110
1111 return Status::OK();
1112 }
1113
InlineSymbolicGradient(const NodeDef & node,FunctionOptimizerContext * ctx,GraphDef * optimized_graph)1114 Status InlineSymbolicGradient(const NodeDef& node,
1115 FunctionOptimizerContext* ctx,
1116 GraphDef* optimized_graph) {
1117 VLOG(2) << "Inline symbolic gradient: " << SummarizeNodeDef(node);
1118
1119 GraphDef graph_def;
1120
1121 // Create a node to anchor the gradient inputs
1122 NodeDef* inlined_input = graph_def.add_node();
1123 inlined_input->set_name("FunctionInputs");
1124 inlined_input->set_op("IdentityN");
1125 AttrValue::ListValue* type_list =
1126 (*inlined_input->mutable_attr())["T"].mutable_list();
1127 for (const auto& type : node.attr().at("Tin").list().type()) {
1128 type_list->add_type(static_cast<DataType>(type));
1129 }
1130
1131 // Add the gradient node
1132 NodeDef* inlined = graph_def.add_node();
1133 *inlined = node;
1134 inlined->clear_input();
1135 for (int i = 0; i < node.attr().at("Tin").list().type_size(); ++i) {
1136 inlined->add_input(strings::StrCat(inlined_input->name(), ":", i));
1137 }
1138
1139 // Create a node to anchor the gradient outputs
1140 NodeDef* inlined_output = graph_def.add_node();
1141 inlined_output->set_name("FunctionOutputs");
1142 inlined_output->set_op("IdentityN");
1143 type_list = (*inlined_output->mutable_attr())["T"].mutable_list();
1144 for (const auto& type : node.attr().at("Tout").list().type()) {
1145 type_list->add_type(static_cast<DataType>(type));
1146 }
1147 for (int i = 0; i < node.attr().at("Tout").list().type_size(); ++i) {
1148 inlined_output->add_input(strings::StrCat(inlined->name(), ":", i));
1149 }
1150
1151 // Convert the graphdef to a graph
1152 GraphConstructorOptions graph_ctor_opts;
1153 graph_ctor_opts.allow_internal_ops = true;
1154 graph_ctor_opts.expect_device_spec = false;
1155 Graph graph(ctx->function_library());
1156 TF_RETURN_IF_ERROR(
1157 ConvertGraphDefToGraph(graph_ctor_opts, graph_def, &graph));
1158
1159 FunctionLibraryRuntime* flr = ctx->mutable_function_library_runtime();
1160
1161 // 1. Inline symbolic gradient node.
1162 const bool expanded = ExpandInlineFunctions(flr, &graph);
1163 if (!expanded) {
1164 return errors::Internal("Failed to expand SymbolicGradient op");
1165 }
1166
1167 // TODO(ezhulenev): InlineFunctionBody in common_runtime/function silently
1168 // fails to inline function into the graph, and leaves the graph unmodified.
1169 // We check that graph has our symbolic gradient inlined, otherwise we return
1170 // a error.
1171 const auto is_symbolic_gradient_op = [&](const Node* node) {
1172 return node->name() == inlined->name() &&
1173 node->type_string() == "SymbolicGradient";
1174 };
1175 for (Node* node : graph.nodes()) {
1176 if (is_symbolic_gradient_op(node)) {
1177 return errors::Internal("Failed to inline symbolic gradient node: ",
1178 SummarizeNode(*node));
1179 }
1180 }
1181
1182 // 2. Recursively inline nested function calls.
1183 int iteration = 0;
1184 while (ExpandInlineFunctions(flr, &graph)) {
1185 if (++iteration >= 50) {
1186 VLOG(2) << "Break symbolic gradient inlining loop at iteration #"
1187 << iteration;
1188 break;
1189 }
1190 }
1191
1192 GraphDef inlined_graph_def;
1193 graph.ToGraphDef(&inlined_graph_def);
1194
1195 // Add the default values of attributes to the nodes that have been inlined.
1196 TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&inlined_graph_def,
1197 *graph.op_registry(), 0, true));
1198
1199 // Add the inlined nodes to the graph
1200 for (NodeDef& inlined_node : *inlined_graph_def.mutable_node()) {
1201 if (inlined_node.name() == "FunctionOutputs") {
1202 inlined_node.set_name(node.name());
1203 for (int i = 0; i < inlined_node.input_size(); ++i) {
1204 inlined_node.set_input(
1205 i, AddPrefixToNodeName(inlined_node.input(i), node.name()));
1206 }
1207 } else if (inlined_node.name() == "FunctionInputs") {
1208 inlined_node.set_name(
1209 AddPrefixToNodeName(inlined_node.name(), node.name()));
1210 inlined_node.clear_input();
1211 for (int i = 0; i < node.input_size(); ++i) {
1212 inlined_node.add_input(node.input(i));
1213 }
1214 } else {
1215 inlined_node.set_name(
1216 AddPrefixToNodeName(inlined_node.name(), node.name()));
1217 for (int i = 0; i < inlined_node.input_size(); ++i) {
1218 inlined_node.set_input(
1219 i, AddPrefixToNodeName(inlined_node.input(i), node.name()));
1220 }
1221 // If the node has no input, hook it up to the function input node to make
1222 // sure it runs in the same frame as the other nodes of the function body.
1223 if (inlined_node.input_size() == 0) {
1224 *inlined_node.add_input() = AsControlDependency(
1225 AddPrefixToNodeName("FunctionInputs", node.name()));
1226 }
1227 }
1228 inlined_node.set_device(node.device());
1229 optimized_graph->add_node()->Swap(&inlined_node);
1230 }
1231
1232 return Status::OK();
1233 }
1234
1235 // -------------------------------------------------------------------------- //
1236 // Inline indirect functions calls (aka PartitionedCallOp).
1237 //
1238 // When we inline indirect function calls, we instantiate the function body from
1239 // its FunctionDef and caller node attributes, and embed the instantiated graph
1240 // into the "main graph".
1241 //
1242 // In contrast to direct function calls, `PartitionedCallOp` has automatic
1243 // dependency tracking via input/output control edges, and we relax some of the
1244 // constraints that we have for direct function call inlining.
1245 //
1246 // Automatic control dependency rules:
1247 //
1248 // 1) When a `PartitionedCallOp` function has a resource (DT_RESOURCE data
1249 // type) input argument it "captures" the mutable resource. This is
1250 // implemented by automatically adding an incoming control edge from the
1251 // previous side-effectful op touching that resource, and an outgoing control
1252 // edge to the next side-effectful op using the same resource. This
1253 // serializes the mutations of the resource to make graph execution
1254 // deterministic.
1255 //
1256 // 2) All stateful ops inside a function body are guaranteed to execute in
1257 // program order, this is achieved by adding control edges between stateful
1258 // ops at graph construction time.
1259 //
1260 // 3) Furthermore, all ops accepting the same resource as an input are
1261 // guaranteed to run in program order. This is also done by adding control
1262 // edges at graph construction time. The last op touching the resource
1263 // will have an outgoing control edge to all function return nodes, which
1264 // will guarantee that all side effects to the resource will happen before
1265 // function completion.
1266 //
1267 // Function call inlining must preserve side effect visibility:
1268 //
1269 // 1) All side effects to the captured resources, that happened before function
1270 // call must be visible to the function body nodes using those resources.
1271 // 2) All side effects to the captured resources, that happened inside function
1272 // body, must be visible to every op/function using that resource after the
1273 // function call completed.
1274 //
1275 // To guarantee that these properties are preserved after inlining we:
1276 //
1277 // 1) Create "input_control" NoOp. Function call node incoming control edges
1278 // will be forwarded *to* this node. Function inputs (Identity nodes) will
1279 // have a control edge *from* this node. If function has no inputs, by
1280 // construction it must have nodes without inputs in the function body, and
1281 // in this case these nodes will have a control edge *from* this node.
1282
1283 // 2) Create "output_control" NoOp. All nodes that have incoming control edge
1284 // *from* the function call node, will be forwarded to this node. Function
1285 // outputs (Identity nodes) will have a control edge *to* this node. This
1286 // will guarantee that nodes that have control dependency on the function
1287 // call, will observe all side-effects (guaranteed by graph construction with
1288 // automatic control dependencies tracking).
1289 //
1290 // If after function instantiation we find a stateful or a dataset op inside
1291 // the function body, that is not reachable from any of the function outputs (or
1292 // if the function has no outputs), we do not inline it, because we can't
1293 // guarantee that these nodes will be executed in correct order (or executed at
1294 // all) after inlining.
1295 //
1296 // We do not try to add any extra control edges to make sure that all
1297 // side-effectful nodes will be executed, that should be handled at graph
1298 // construction time.
1299
1300 struct MaybeDeadOutput {
1301 const NodeDef* dead_tensor_src;
1302 const NodeDef* output_node_dst;
1303 };
1304
1305 // Finds all function outputs that might return a dead tensor. This can happen
1306 // if there is no `Merge` node on the path from the `Switch` node, to the
1307 // function output.
MaybeDeadOutputs(const FunctionOptimizerContext & ctx,const GrapplerFunctionItem & item,std::vector<MaybeDeadOutput> * maybe_dead)1308 Status MaybeDeadOutputs(const FunctionOptimizerContext& ctx,
1309 const GrapplerFunctionItem& item,
1310 std::vector<MaybeDeadOutput>* maybe_dead) {
1311 VLOG(3) << "Find function outputs that might return dead tensors: item.id="
1312 << item.id;
1313 DCHECK(maybe_dead->empty()) << "Input argument must be an empty vector";
1314
1315 std::vector<const NodeDef*> dead_tensor_srcs;
1316 for (const NodeDef& node : item.graph.node()) {
1317 if (IsSwitch(node)) {
1318 VLOG(4) << "Add dead tensors source. Switch node: " << node.name();
1319 dead_tensor_srcs.push_back(&node);
1320 continue;
1321 }
1322
1323 // Regular (aka 'direct') function call can also produce dead tensors if
1324 // the function body has mergeless switches.
1325 const FunctionDef* func = ctx.function_library().Find(node.op());
1326 if (func != nullptr) {
1327 GrapplerFunctionItem func_item;
1328 TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
1329 *func, FunctionInstantiationAttributes(*func, node),
1330 ctx.function_library(), ctx.graph_version(), &func_item));
1331
1332 std::vector<MaybeDeadOutput> func_dead_outputs;
1333 TF_RETURN_IF_ERROR(MaybeDeadOutputs(ctx, func_item, &func_dead_outputs));
1334
1335 if (!func_dead_outputs.empty()) {
1336 VLOG(4) << "Add dead tensors source. Function call: " << node.op()
1337 << " node=" << node.name();
1338 dead_tensor_srcs.push_back(&node);
1339 }
1340 }
1341 }
1342
1343 // If we do not have dead tensor sources in the function body, it's
1344 // guaranteed that all output tensors can't become dead.
1345 if (dead_tensor_srcs.empty()) return Status::OK();
1346
1347 // Names of the function body nodes that return function output values.
1348 absl::flat_hash_set<absl::string_view> output_nodes;
1349 for (const auto& output_expansion : item.outputs()) {
1350 for (const auto& output_node : output_expansion.output_nodes) {
1351 output_nodes.insert(output_node);
1352 }
1353 }
1354
1355 GraphTopologyView topology_view;
1356 TF_RETURN_IF_ERROR(topology_view.InitializeFromGraph(item.graph));
1357
1358 for (const NodeDef* dead_tensor_src : dead_tensor_srcs) {
1359 DfsTraversal(topology_view, {dead_tensor_src},
1360 TraversalDirection::kFollowOutputs,
1361 // Stop traversal when reached first `Merge` node.
1362 DfsPredicates::Advance(
1363 [](const NodeDef* node) { return !IsMerge(*node); }),
1364 // If we reached output node, add MaybeDeadOutput edge.
1365 DfsCallbacks::PreOrder([&](const NodeDef* node) {
1366 if (output_nodes.find(node->name()) != output_nodes.end()) {
1367 maybe_dead->push_back({dead_tensor_src, node});
1368 }
1369 }));
1370 }
1371
1372 return Status::OK();
1373 }
1374
1375 // Returns `Status::OK()` iff `node` is an indirect function call of `func`, and
1376 // we know how to inline it into the main graph, otherwise returns and error
1377 // indicating why the function call is not inlinable.
IsInlinableIndirectFunctionCall(const FunctionOptimizerContext & ctx,const FunctionDef & func,const NodeDef & func_node)1378 Status IsInlinableIndirectFunctionCall(const FunctionOptimizerContext& ctx,
1379 const FunctionDef& func,
1380 const NodeDef& func_node) {
1381 // We inline direct function calls above, using different rules.
1382 if (!IsIndirectFunctionCall(func, func_node)) {
1383 return errors::InvalidArgument("Unsupported function call type: ",
1384 SummarizeNodeDef(func_node));
1385 }
1386
1387 if (MarkedNoInline(func)) {
1388 return errors::FailedPrecondition(
1389 "Can't inline function marked with '_noinline': ",
1390 SummarizeNodeDef(func_node));
1391 }
1392
1393 // Function specialization and inlining must be mutually exclusive.
1394 if (MarkedSpecialized(func)) {
1395 return errors::FailedPrecondition(
1396 "Can't inline function created in Grappler function specialization: ",
1397 SummarizeNodeDef(func_node));
1398 }
1399
1400 // We can't inline functions that are in a fetch set, because it would
1401 // invalidate fetch tensors (function call node fully inlined and doesn't
1402 // exist in the optimized graph).
1403 if (ctx.IsFetchNode(func_node.name())) {
1404 return errors::FailedPrecondition(
1405 "Can't inline function in a Grappler item fetch set: ",
1406 SummarizeNodeDef(func_node));
1407 }
1408
1409 return Status::OK();
1410 }
1411
1412 // Checks that all side-effects will be executed in well defined order. We do it
1413 // by checking if there is a path from stateful/dataset ops to one of the
1414 // control output nodes.
CheckThatSideEffectsWillExecute(const FunctionOptimizerContext & ctx,const GraphTopologyView & graph_topo_view,const absl::flat_hash_set<string> control_output_nodes)1415 Status CheckThatSideEffectsWillExecute(
1416 const FunctionOptimizerContext& ctx,
1417 const GraphTopologyView& graph_topo_view,
1418 const absl::flat_hash_set<string> control_output_nodes) {
1419 // In aggressive mode we just print a warning for side-effectful nodes that
1420 // might not be executed after inlining.
1421 const bool aggressive = ctx.opt_level() == RewriterConfig::AGGRESSIVE;
1422
1423 for (const NodeDef& func_body_node : graph_topo_view.graph()->node()) {
1424 const bool node_must_execute =
1425 IsDataset(func_body_node) ||
1426 IsStateful(func_body_node, &ctx.function_library());
1427
1428 // If op has DT_RESOURCE argument it will be marked as stateful, though if
1429 // it only reads from that resource, it's allowed to prune it, because it
1430 // can't produce any visible side-effects.
1431 const bool read_only = IsReadVariableOp(func_body_node);
1432
1433 if (read_only || !node_must_execute) continue;
1434
1435 VLOG(3) << "Check that node " << func_body_node.name()
1436 << " will execute after inlining.";
1437 bool will_execute = false;
1438
1439 // Check if we reached one of the output nodes.
1440 const auto callbacks = DfsCallbacks::PreOrder([&](const NodeDef* node) {
1441 if (control_output_nodes.contains(node->name())) {
1442 VLOG(4) << "Found a path to control output node: " << node->name();
1443 will_execute = true;
1444 }
1445 });
1446
1447 // Stop if we already proved that node will execute.
1448 const auto predicates = DfsPredicates::Enter(
1449 [&](const NodeDef* node) { return !will_execute; });
1450
1451 DfsTraversal(graph_topo_view, {&func_body_node},
1452 TraversalDirection::kFollowOutputs, predicates, callbacks);
1453
1454 if (!will_execute) {
1455 const string error_message = absl::StrCat(
1456 "Can't guarantee execution of a side-effectful node, that is not "
1457 "reachable from function outputs. Function body node: ",
1458 SummarizeNodeDef(func_body_node));
1459
1460 if (aggressive) {
1461 LOG(WARNING) << error_message;
1462 } else {
1463 return errors::Internal(error_message);
1464 }
1465 }
1466 }
1467
1468 return Status::OK();
1469 }
1470
PlaceInlinedFunctionBody(const NodeDef & func_node,const GrapplerFunctionItem & item,const absl::flat_hash_map<absl::string_view,int> & input_placeholders_idx,FunctionOptimizerContext * ctx,GraphDef * placed_graph_def)1471 Status PlaceInlinedFunctionBody(
1472 const NodeDef& func_node, const GrapplerFunctionItem& item,
1473 const absl::flat_hash_map<absl::string_view, int>& input_placeholders_idx,
1474 FunctionOptimizerContext* ctx, GraphDef* placed_graph_def) {
1475 // Control flow lowering and Placer works with a Graph object.
1476 std::unique_ptr<Graph> func_body_graph =
1477 absl::make_unique<Graph>(ctx->function_library());
1478
1479 GraphConstructorOptions opts;
1480 TF_RETURN_IF_ERROR(
1481 ConvertGraphDefToGraph(opts, item.graph, func_body_graph.get()));
1482
1483 // ------------------------------------------------------------------------ //
1484 // Grappler receives the graph after PRE_PLACEMENT, Placer, and POST_PLACEMENT
1485 // passes, so each node has a valid device assignment. Also V2 control
1486 // flow ops (functional If and While) should have been lowered to V1 control
1487 // flow (Switch and Merge nodes). To keep the graph valid for execution we
1488 // must assign device to every inlined graph node, and also lower the control
1489 // flow.
1490
1491 GraphOptimizationPassOptions opt_options;
1492 opt_options.graph = &func_body_graph;
1493 opt_options.flib_def = ctx->mutable_function_library();
1494
1495 // TODO(ezhulenev): Should we run full PRE_PLACEMENT pass here? And
1496 // POST_PLACEMENT after placer?
1497 LowerIfWhilePass pass;
1498 TF_RETURN_IF_ERROR(pass.Run(opt_options));
1499
1500 // ------------------------------------------------------------------------ //
1501 // Before placing the function body nodes we pin input placeholders to the
1502 // same device as their corresponding input nodes.
1503
1504 for (Node* func_body_node : func_body_graph->nodes()) {
1505 const auto input_placeholder_idx =
1506 input_placeholders_idx.find(func_body_node->name());
1507
1508 if (input_placeholder_idx != input_placeholders_idx.end()) {
1509 const int input_idx = input_placeholder_idx->second;
1510 const GraphView::OutputPort output_port =
1511 ctx->graph_view().GetRegularFanin({&func_node, input_idx});
1512
1513 VLOG(3) << "Pin inlined function input node '" << func_body_node->name()
1514 << "' to the '" << output_port.node->device() << "' device.";
1515 func_body_node->set_requested_device(output_port.node->device());
1516 }
1517 }
1518
1519 // ------------------------------------------------------------------------ //
1520 // After placing nodes corresponding to the function inputs, we need to assign
1521 // device placements to all other function body nodes.
1522
1523 const DeviceSet* devices = ctx->devices();
1524
1525 if (devices->devices().empty()) {
1526 // If there are no devices available for placer, we just put all nodes to
1527 // the same device as a function caller node. This can happen if Grappler is
1528 // running "offline", without active runtime session, for example as a part
1529 // of a batch job for graph analysis/optimization.
1530 VLOG(3) << "Assign function call node device to all function body nodes. "
1531 << "Device: " << func_node.device();
1532 for (Node* func_body_node : func_body_graph->nodes()) {
1533 func_body_node->set_requested_device(func_node.device());
1534 }
1535 } else {
1536 // If we are running in an active runtime session, Grappler will get the
1537 // graph after initial placing is done, and we should have devices for the
1538 // placer.
1539 VLOG(3) << "Run placer for instantiated function body. Devices: ["
1540 << absl::StrJoin(
1541 devices->devices(), ", ",
1542 [](string* out, const Device* d) { out->append(d->name()); })
1543 << "]";
1544
1545 // Use function caller node device as a default for placer.
1546 const Device* default_device =
1547 devices->FindDeviceByName(func_node.device());
1548
1549 Placer placer(func_body_graph.get(), devices, default_device);
1550 TF_RETURN_IF_ERROR(placer.Run());
1551 }
1552
1553 // Convert Graph back to the placed GraphDef.
1554 func_body_graph->ToGraphDef(placed_graph_def);
1555
1556 return Status::OK();
1557 }
1558
InlineIndirectFunctionCall(const NodeDef & func_node,const FunctionDef & func,FunctionOptimizerContext * ctx,GraphDef * optimized_graph)1559 Status InlineIndirectFunctionCall(const NodeDef& func_node,
1560 const FunctionDef& func,
1561 FunctionOptimizerContext* ctx,
1562 GraphDef* optimized_graph) {
1563 VLOG(2) << "Inline indirect function call: " << SummarizeNodeDef(func_node);
1564 VLOG(4) << "Inlined function definition: " << DebugString(func);
1565 TF_RETURN_IF_ERROR(IsInlinableIndirectFunctionCall(*ctx, func, func_node));
1566
1567 const AttrSlice func_instantiation_attr =
1568 FunctionInstantiationAttributes(func, func_node);
1569
1570 GrapplerFunctionItem item;
1571 Status item_status = MakeGrapplerFunctionItem(func, func_instantiation_attr,
1572 ctx->function_library(),
1573 ctx->graph_version(), &item);
1574
1575 if (!item_status.ok()) {
1576 return errors::InvalidArgument("Failed to inline function ", func_node.op(),
1577 " instantiated by ", func_node.name(),
1578 ". Error: ", item_status.error_message());
1579 }
1580
1581 // `PartitionedCallOp` invokes functions with `allow_dead_tensors = true` to
1582 // reset dead flag, and return default initialized tensors instead of a dead
1583 // tensors. There is no way to express this in a regular Tensorflow graph, so
1584 // we choose not to inline if a function can have dead tensors as an output
1585 // position. In practice `mergeless switches` should not exists in a function
1586 // body, because tf-eager will only use v2 control flow ops.
1587 std::vector<MaybeDeadOutput> maybe_dead_outputs;
1588 TF_RETURN_IF_ERROR(MaybeDeadOutputs(*ctx, item, &maybe_dead_outputs));
1589 if (!maybe_dead_outputs.empty()) {
1590 struct MaybeDeadOutputFormatter {
1591 void operator()(string* out, const MaybeDeadOutput& md) const {
1592 absl::StrAppend(out, SummarizeNodeDef(*md.dead_tensor_src));
1593 }
1594 };
1595 return errors::FailedPrecondition(
1596 "Can't inline function with dead outputs. Dead tensor sources (size = ",
1597 maybe_dead_outputs.size(), "): ",
1598 absl::StrJoin(maybe_dead_outputs, "\n", MaybeDeadOutputFormatter()));
1599 }
1600
1601 GraphView::InputPort control_input_port =
1602 ctx->graph_view().GetInputPort(func_node.name(), Graph::kControlSlot);
1603 GraphView::OutputPort control_output_port =
1604 ctx->graph_view().GetOutputPort(func_node.name(), Graph::kControlSlot);
1605
1606 // Nodes that have side effects to the captured resources.
1607 std::vector<string> happens_before;
1608 absl::c_transform(
1609 ctx->graph_view().GetFanin(control_input_port),
1610 std::back_inserter(happens_before),
1611 [](const GraphView::OutputPort port) { return port.node->name(); });
1612
1613 VLOG(3) << "Happens before set (size = " << happens_before.size()
1614 << "): " << absl::StrJoin(happens_before, ", ");
1615
1616 // Nodes that must observe side effects to the captured resources.
1617 std::vector<string> happens_after;
1618 absl::c_transform(
1619 ctx->graph_view().GetFanout(control_output_port),
1620 std::back_inserter(happens_after),
1621 [](const GraphView::InputPort port) { return port.node->name(); });
1622
1623 VLOG(3) << "Happens after set (size = " << happens_after.size()
1624 << "): " << absl::StrJoin(happens_after, ", ");
1625
1626 // Regular (data) inputs to the function call.
1627 std::vector<SafeTensorId> inputs;
1628 for (const string& input : func_node.input()) {
1629 SafeTensorId tensor_id = ParseTensorName(input);
1630 if (tensor_id.index() == Graph::kControlSlot) break;
1631 inputs.push_back(tensor_id);
1632 }
1633
1634 // Mapping from input placeholder name to function input position.
1635 absl::flat_hash_map<absl::string_view, int> input_placeholders_idx;
1636 for (const InputArgExpansion& input_arg : item.inputs()) {
1637 for (const string& placeholder : input_arg.placeholders) {
1638 const int idx = input_placeholders_idx.size();
1639 input_placeholders_idx[placeholder] = idx;
1640 }
1641 }
1642
1643 const string prefix = strings::StrCat(func_node.name(), "/");
1644
1645 // ------------------------------------------------------------------------ //
1646 // For each function output value we added an identity node that reads the
1647 // tensor from one of the function body nodes. When we inline function into
1648 // the main graph we want to bypass these nodes, so we keep a mapping from
1649 // 'output node name' -> 'output tensor name'.
1650 absl::flat_hash_map<string, string> output_tensors;
1651
1652 // Unique names of nodes producing tensors in `output_tensors`.
1653 absl::flat_hash_set<string> output_tensors_nodes;
1654
1655 // Identity nodes added to the function body in place of function outputs.
1656 absl::flat_hash_set<string> output_nodes;
1657 for (const OutputArgExpansion& output_arg : item.outputs()) {
1658 for (const string& output_node : output_arg.output_nodes) {
1659 output_nodes.insert(output_node);
1660 }
1661 }
1662
1663 for (const NodeDef& func_body_node : item.graph.node()) {
1664 const string& node_name = func_body_node.name();
1665
1666 if (IsIdentity(func_body_node) && output_nodes.count(node_name)) {
1667 const string& output_tensor = func_body_node.input(0);
1668 output_tensors.emplace(node_name, output_tensor);
1669
1670 SafeTensorId tensor_id = ParseTensorName(output_tensor);
1671 output_tensors_nodes.insert(tensor_id.node());
1672 }
1673 }
1674
1675 // ------------------------------------------------------------------------ //
1676 // IMPORTANT: Actual inputs will be added to the following nodes at the very
1677 // last stage, because we don't want to have invalid edges in a function body
1678 // graph (control edges that depend on the nodes in the "outer" optimized
1679 // graph).
1680
1681 // If one of the function inputs is a dead tensor, we must not execute any of
1682 // the function body nodes, and let the dead tensor flag propagate through the
1683 // inlined function body. We add NoOp inputs_ready node, and add control edges
1684 // to it from all input nodes. Inlined function arguments (Identity nodes)
1685 // will have a control dependency on it.
1686 //
1687 // TODO(ezhulenev): We do not need to provide this guarantee for ALL nodes in
1688 // the function body. We must only ensure that we do not generate observable
1689 // side effects.
1690 //
1691 // If the function call node has incoming control edges, we will update them
1692 // to use this node as destination, to ensure side-effects execution order.
1693 NodeDef* inputs_ready_node = nullptr;
1694 if (func_node.input_size() > 0) {
1695 inputs_ready_node = item.graph.add_node();
1696 inputs_ready_node->set_op("NoOp");
1697 inputs_ready_node->set_name(kInputsReadyNodeName);
1698 }
1699
1700 // All nodes that have a control edge from the function call node, will be
1701 // updated to have a control edge from 'side_effects_executed_node`. This node
1702 // will have control edges from all function control outputs (see
1703 // `control_ret` in FunctionDef). This a "barrier" that guarantees that all
1704 // ops with side effects in the function body were executed
1705 //
1706 // If the function call node has no outgoing control edges, it means that no
1707 // one is interested in the function side-effect affecting captured resources.
1708 //
1709 // If node is in keep_ops set, it means that it must execute. This could
1710 // happen if the graph is an instantiation of a function with control output.
1711 NodeDef* side_effects_executed_node = nullptr;
1712 if (!happens_after.empty() || ctx->IsKeepOp(func_node.name())) {
1713 side_effects_executed_node = item.graph.add_node();
1714 side_effects_executed_node->set_op("NoOp");
1715 side_effects_executed_node->set_name(kSideEffectsExecutedNodeName);
1716 }
1717
1718 // If function executed only for the regular data outputs, it's totally safe
1719 // to prune side-effects. If side-effects order is important, it must be
1720 // captured at graph construction time via control edges.
1721 if (item.control_output_size() > 0 && happens_after.empty()) {
1722 VLOG(2) << "Function has control outputs and empty happens after set.";
1723 }
1724
1725 // ------------------------------------------------------------------------ //
1726 // If we have a node inside the function body without inputs (e.g. Const), we
1727 // must attach a control dependency to it, to make sure that if a function
1728 // call happens inside a loop, the node will be evaluated in correct frame.
1729 //
1730 // If the function call node has no inputs and no control dependencies, it
1731 // means that it can't be a function call inside a loop, and we can safely
1732 // insert that node without inputs into the main graph.
1733 //
1734 // TODO(ezhulenev): Use FrameMap (see grappler/utils/frame.h) to find out if
1735 // the function is called inside a loop.
1736 std::vector<string> empty_inputs_hook;
1737 if (inputs_ready_node != nullptr) {
1738 empty_inputs_hook.push_back(inputs_ready_node->name());
1739 }
1740
1741 // ------------------------------------------------------------------------ //
1742 // Grappler called after PRE_PLACEMENT and PLACEMENT passes, so we have to
1743 // make sure that after inlining all nodes will have valid device assignment.
1744
1745 GraphDef placed_graph_def;
1746 TF_RETURN_IF_ERROR(PlaceInlinedFunctionBody(
1747 func_node, item, input_placeholders_idx, ctx, &placed_graph_def));
1748
1749 // ------------------------------------------------------------------------ //
1750 // After all nodes placed we need to prepare them for inlining into the
1751 // optimized graph: turn placeholders into identities, update nodes
1752 // connectivity, etc...
1753
1754 const auto inlined_node_name = [&func_node](const string& name) -> string {
1755 return AddPrefixToNodeName(name, /*prefix=*/func_node.name());
1756 };
1757
1758 for (NodeDef& func_body_node : *placed_graph_def.mutable_node()) {
1759 const string& node_name = func_body_node.name();
1760
1761 // Turn placeholders added in place of input arguments into identity nodes.
1762 const auto input_placeholder_idx = input_placeholders_idx.find(node_name);
1763 if (input_placeholder_idx != input_placeholders_idx.end()) {
1764 DCHECK_EQ(0, func_body_node.input_size());
1765 func_body_node.set_op("Identity");
1766 (*func_body_node.mutable_attr())["T"] = func_body_node.attr().at("dtype");
1767 func_body_node.mutable_attr()->erase("dtype");
1768 func_body_node.mutable_attr()->erase("shape");
1769 const int input_idx = input_placeholder_idx->second;
1770 func_body_node.add_input(inputs[input_idx].ToString());
1771
1772 // Add a control dependency on 'inputs_ready' node, to guarantee that all
1773 // inputs are alive and all side-effects executed before function body.
1774 if (inputs_ready_node) {
1775 func_body_node.add_input(
1776 AsControlDependency(inlined_node_name(inputs_ready_node->name())));
1777 }
1778 } else {
1779 // Update inputs of the regular function body nodes.
1780 for (string& input : *func_body_node.mutable_input()) {
1781 input = inlined_node_name(input);
1782 }
1783
1784 // Check if we need to ensure node execution in correct loop frame.
1785 bool node_needs_empty_inputs_hook =
1786 // We have a node to hook and node has no inputs.
1787 !empty_inputs_hook.empty() && func_body_node.input_size() == 0 &&
1788 // Inputs ready node will always have edge from main graph. If
1789 // function call has no regular and control inputs, we will not add
1790 // inputs_ready node to the function body graph.
1791 node_name != kInputsReadyNodeName &&
1792 // The node acting as a return barrier for execution of side effects
1793 // might not have any inputs (in case function has no control outputs,
1794 // but we still added it because of non-empty happens-after set), so
1795 // we must make sure it's executed in correct frame.
1796 (node_name != kSideEffectsExecutedNodeName ||
1797 item.control_output_size() == 0);
1798
1799 if (node_needs_empty_inputs_hook) {
1800 *func_body_node.add_input() =
1801 AsControlDependency(inlined_node_name(empty_inputs_hook[0]));
1802 }
1803 }
1804
1805 // Add the function node name as a prefix 1) to node name to avoid
1806 // collisions; 2) to frame name to avoid multiple LoopCond nodes in one
1807 // frame after inlining.
1808 TF_RETURN_IF_ERROR(
1809 AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &func_body_node));
1810
1811 // After inlining into the optimized graph, NodeDef must have all attributes
1812 // defined, which is not required for a node in a FunctionDef.
1813 const OpDef* op_def;
1814 TF_RETURN_IF_ERROR(
1815 ctx->function_library().LookUpOpDef(func_body_node.op(), &op_def));
1816 AddDefaultsToNodeDef(*op_def, &func_body_node);
1817 }
1818
1819 // ------------------------------------------------------------------------ //
1820 // Check that after inlining all side-effects will be executed in well defined
1821 // order. We do it by checking if there is a path from stateful/dataset ops to
1822 // one of the output nodes.
1823
1824 // Because we rename all the nodes before inlining, we need a copy of
1825 // output_nodes with a new names.
1826 absl::flat_hash_set<string> inlined_output_nodes;
1827 for (const string& output_node : output_nodes) {
1828 inlined_output_nodes.insert(inlined_node_name(output_node));
1829 }
1830 const auto is_inlined_output_node = [&](const NodeDef& node) -> bool {
1831 return inlined_output_nodes.find(node.name()) != inlined_output_nodes.end();
1832 };
1833
1834 // Names of the inlined control output nodes.
1835 absl::flat_hash_set<string> inlined_control_output_nodes;
1836 for (const ControlOutput& control_output : item.control_outputs()) {
1837 inlined_control_output_nodes.insert(
1838 inlined_node_name(control_output.node_name));
1839 }
1840
1841 // Construct a graph topology view for DFS traversals (skip invalid edges for
1842 // input nodes connected to nodes in the optimized graph).
1843 GraphTopologyView placed_topo_view(/*skip_invalid_edges=*/true);
1844 TF_RETURN_IF_ERROR(placed_topo_view.InitializeFromGraph(placed_graph_def));
1845 TF_RETURN_IF_ERROR(CheckThatSideEffectsWillExecute(
1846 *ctx, placed_topo_view, inlined_control_output_nodes));
1847
1848 // ------------------------------------------------------------------------ //
1849 // Move all the nodes to the optimized graph after successful preprocessing.
1850
1851 if (inputs_ready_node != nullptr) {
1852 string inlined_node = inlined_node_name(inputs_ready_node->name());
1853 absl::optional<int> node_idx = placed_topo_view.GetNodeIndex(inlined_node);
1854
1855 absl::flat_hash_set<string> input_nodes;
1856 for (const string& input : func_node.input()) {
1857 SafeTensorId tensor = ParseTensorName(input);
1858
1859 // Input node might have been a function call that was already inlined.
1860 auto it = ctx->tensor_mapping().find(tensor);
1861 while (it != ctx->tensor_mapping().end()) {
1862 tensor = it->second;
1863 it = ctx->tensor_mapping().find(tensor);
1864 }
1865
1866 if (input_nodes.insert(tensor.node()).second) {
1867 placed_graph_def.mutable_node(*node_idx)->add_input(
1868 AsControlDependency(tensor.node()));
1869 }
1870 }
1871 }
1872
1873 if (side_effects_executed_node != nullptr) {
1874 string inlined_node = inlined_node_name(side_effects_executed_node->name());
1875 absl::optional<int> node_idx = placed_topo_view.GetNodeIndex(inlined_node);
1876
1877 // Add control edges from all control output nodes.
1878 for (const string& node_name : inlined_control_output_nodes) {
1879 placed_graph_def.mutable_node(*node_idx)->add_input(
1880 AsControlDependency(node_name));
1881 }
1882
1883 // Forward all control dependencies in the optimized graph to the new node.
1884 ctx->AddControlOverrides(func_node, {inlined_node});
1885 }
1886
1887 for (NodeDef& func_body_node : *placed_graph_def.mutable_node()) {
1888 // Skip output identity nodes.
1889 if (IsIdentity(func_body_node) && is_inlined_output_node(func_body_node))
1890 continue;
1891
1892 optimized_graph->add_node()->Swap(&func_body_node);
1893 }
1894
1895 // Indirect function call is fully inlined into the optimized graph, and we do
1896 // not copy the original function call node, so we have to setup tensor
1897 // mapping from old output tensors, to the outputs of inlined nodes.
1898 int output_idx = 0;
1899 for (const OutputArgExpansion& output : item.outputs()) {
1900 for (const string& output_node : output.output_nodes) {
1901 const string& output_tensor = output_tensors.at(output_node);
1902
1903 const SafeTensorId from_tensor(func_node.name(), output_idx++);
1904 const SafeTensorId to_tensor = ParseTensorName(output_tensor);
1905
1906 const SafeTensorId inlined_to_tensor =
1907 SafeTensorId(absl::StrCat(func_node.name(), "/", to_tensor.node()),
1908 to_tensor.index());
1909
1910 ctx->AddTensorMapping(from_tensor, inlined_to_tensor);
1911 }
1912 }
1913
1914 // If function call node was in keep_ops set, it means that we need to keep a
1915 // node with the same name in the optimized graph. We forward all data
1916 // consumers to inlined nodes, and we verify that the node is not in a fetch
1917 // set, so it's safe to assume that the function call node is only required
1918 // for a control edge source.
1919 if (ctx->IsKeepOp(func_node.name())) {
1920 VLOG(4) << "Add NoOp for inlined function in keep ops set.";
1921 NodeDef* keep_func_node = optimized_graph->add_node();
1922 keep_func_node->set_op("NoOp");
1923 keep_func_node->set_name(func_node.name());
1924 keep_func_node->set_device(func_node.device());
1925 keep_func_node->add_input(
1926 AsControlDependency(inlined_node_name(kSideEffectsExecutedNodeName)));
1927 }
1928
1929 VLOG(3) << "Successfully inlined indirect function call: "
1930 << SummarizeNodeDef(func_node);
1931
1932 return Status::OK();
1933 }
1934
1935 // Restores graph invariants after function specialization and inlining: all
1936 // inputs must be connected to valid nodes.
RestoreGraphInvariants(const FunctionOptimizerContext & ctx,GraphDef * optimized_graph)1937 Status RestoreGraphInvariants(const FunctionOptimizerContext& ctx,
1938 GraphDef* optimized_graph) {
1939 // After function specialization and inlining graph might be in invalid
1940 // state, and some nodes can read tensors that do not exists anymore in the
1941 // optimized graph: function call node was fully inlined into the graph, or
1942 // output index was invalidated by the output pruning.
1943
1944 if (!ctx.tensor_mapping().empty()) {
1945 for (NodeDef& node : *optimized_graph->mutable_node()) {
1946 for (int idx = 0; idx < node.input_size(); ++idx) {
1947 TensorId input_tensor = ParseTensorName(node.input(idx));
1948 if (input_tensor.index() == Graph::kControlSlot) break;
1949
1950 auto mapping = ctx.tensor_mapping().find(input_tensor);
1951 if (mapping != ctx.tensor_mapping().end()) {
1952 node.set_input(idx, mapping->second.ToString());
1953 }
1954 }
1955 }
1956 }
1957
1958 // Function inlining instantiates function body directly into the optimized
1959 // graph, and we might end up with control dependencies to the nodes that no
1960 // longer exist in a graph. We need to apply control overrides to all
1961 // invalidated nodes, and rewire control dependencies to the control outputs
1962 // node (it's also possible to rewrite singe control edge into multiple edges
1963 // to inlined side-effectful nodes).
1964
1965 if (!ctx.control_overrides().empty()) {
1966 for (NodeDef& node : *optimized_graph->mutable_node()) {
1967 // Keep track of new control inputs to the node.
1968 absl::flat_hash_set<string> add_ctrl_inputs;
1969
1970 // Remove all invalidated control inputs.
1971 for (int idx = 0; idx < node.input_size(); /* see below */) {
1972 // TODO(ezhulenev): Use non-allocating TensorId after migrating
1973 // `control_overrides()` to absl::flat_hash_set.
1974 SafeTensorId input_tensor = ParseTensorName(node.input(idx));
1975
1976 auto overrides = ctx.control_overrides().find(input_tensor.node());
1977 if (overrides != ctx.control_overrides().end()) {
1978 // If this happens it's a bug in the function inlining.
1979 if (input_tensor.index() != Graph::kControlSlot) {
1980 return errors::Internal(
1981 "Illegal input edge from inlined function call node");
1982 }
1983 // Remove control dependency to the inlined function call node.
1984 node.mutable_input()->SwapElements(idx, node.input_size() - 1);
1985 node.mutable_input()->RemoveLast();
1986
1987 // Keep track of all overrides.
1988 for (const string& override : overrides->second) {
1989 add_ctrl_inputs.insert(AsControlDependency(override));
1990 }
1991 } else {
1992 // Go to the next input only if the current one was not invalidated,
1993 // otherwise we need to check the swapped input as well.
1994 ++idx;
1995 }
1996 }
1997
1998 // Add overrides to the node inputs.
1999 for (const string& ctrl_input : add_ctrl_inputs) {
2000 node.add_input(ctrl_input);
2001 }
2002 }
2003 }
2004
2005 return Status::OK();
2006 }
2007
2008 } // namespace
2009
// Runs a single function optimization pass over `graph`, and writes the result
// to `optimized_graph`.
//
// Nodes are visited in topological order. Depending on the optimizer options,
// each node is either inlined (symbolic gradients, direct and indirect function
// calls), replaced with a call to a specialized function, or copied to the
// `optimized_graph` as is.
//
// Arguments:
//   item:      Grappler item used to construct the optimizer context (its id is
//              also used for logging).
//   graph:     the graph to optimize.
//   iteration: the pass number, used only for logging.
//   skip_nodes: names of the nodes to copy without trying to optimize them.
//              Nodes that this pass fails or declines to optimize are added to
//              this set.
//   optimized_graph: output graph. Nodes are only ever added to it.
//   graph_has_unoptimized_function_calls: set to true if `optimized_graph`
//              still has a symbolic gradient (when gradient inlining is
//              enabled) or a call to a non-specialized function, that is not in
//              `skip_nodes`; set to false otherwise.
//
// Returns an error if the topological sort fails, if a node optimization fails
// after it has already modified the `optimized_graph`, or if restoring graph
// invariants fails.
Status FunctionOptimizer::RunFunctionOptimizerPass(
    const GrapplerItem& item, const GraphDef& graph, const int iteration,
    std::unordered_set<string>* skip_nodes, GraphDef* optimized_graph,
    bool* graph_has_unoptimized_function_calls) const {
  VLOG(3) << absl::Substitute(
      "Run function optimizer pass (iteration = $0): grappler_item_id = $1",
      iteration, item.id);

  FunctionOptimizerContext ctx(item, opt_level_, graph);

  bool inline_gradients = options_.enable_symbolic_gradient_inlining;
  bool inline_func = options_.enable_function_inlining;
  bool specialize_func = options_.enable_function_specialization;

  // We will process all the nodes in topological order, to correctly handle
  // inlining of function call chains.
  std::vector<const NodeDef*> topo_ordered_nodes;
  TF_RETURN_IF_ERROR(ComputeTopologicalOrder(graph, &topo_ordered_nodes));

  for (const NodeDef* node : topo_ordered_nodes) {
    // Each node optimization can modify optimized graph only by adding new
    // nodes, we can check node size to make sure that graph was not modified.
    const int num_nodes_before = optimized_graph->node_size();
    const auto is_graph_modified = [&]() {
      int num_nodes = optimized_graph->node_size();
      DCHECK_GE(num_nodes, num_nodes_before) << "Nodes should not be removed";
      return num_nodes > num_nodes_before;
    };

    // Copy node from the `graph` to the `optimized_graph`.
    const auto copy_node = [&]() { *optimized_graph->add_node() = *node; };

    // If we already failed to optimize this node during one of the previous
    // passes, we just give up, and do not try one more time.
    if (skip_nodes->find(node->name()) != skip_nodes->end()) {
      VLOG(3) << "Skip optimization for node: " << node->name();
      copy_node();
      continue;
    }

    // Skip errors if optimized graph was not modified before error happened.
    //
    // If the wrapped expression fails after it added nodes to the optimized
    // graph, the error is returned from the enclosing function. If it fails
    // without modifying the graph, the error is only logged, the node is added
    // to `skip_nodes`, and copied to the optimized graph unchanged. The macro
    // itself never skips to the next node: every use below is followed by an
    // explicit `continue`, because in all non-returning cases the node was
    // already handled (optimized or copied).
#define TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(...)                      \
  do {                                                              \
    const Status _status = (__VA_ARGS__);                           \
    if (TF_PREDICT_FALSE(!_status.ok() && is_graph_modified()))     \
      return _status;                                               \
    if (TF_PREDICT_FALSE(!_status.ok() && !is_graph_modified())) {  \
      VLOG(3) << "Skip error: " << _status.error_message();         \
      skip_nodes->insert(node->name());                             \
      copy_node();                                                  \
    }                                                               \
  } while (0)

    // ---------------------------------------------------------------------- //
    // 1. Inline symbolic gradients into the optimized graph.                 //
    // ---------------------------------------------------------------------- //

    if (IsSymbolicGradient(*node) && inline_gradients) {
      // Inline symbolic gradients only if the corresponding function is not
      // marked as `_noinline`.
      const auto* f_attr = gtl::FindOrNull(node->attr(), "f");
      const string f_name = f_attr != nullptr ? f_attr->func().name() : "";
      const FunctionDef* func = ctx.function_library().Find(f_name);
      if (func && !MarkedNoInline(*func)) {
        TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
            InlineSymbolicGradient(*node, &ctx, optimized_graph));
        continue;
      } else {
        // The node is not copied here: it falls through to the stages below.
        VLOG(2) << "Skip SymbolicGradient inlining: function=" << f_name;
        skip_nodes->insert(node->name());
      }
    }

    // ---------------------------------------------------------------------- //
    // 2. Inline or specialize function calls.                                //
    // ---------------------------------------------------------------------- //

    // Find if a node is a function call (direct or indirect).
    const FunctionDef* func = FindFunctionCall(ctx, *node);

    if (func != nullptr) {
      const string& func_name = func->signature().name();

      const bool is_direct_func = IsDirectFunctionCall(*func, *node);
      const bool is_indirect_func = IsIndirectFunctionCall(*func, *node);

      // 2a. Inline direct function call if it's inlinable.
      if (inline_func && is_direct_func) {
        Status inlinable = IsInlinableDirectFunctionCall(ctx, *func, *node);
        if (inlinable.ok()) {
          TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
              InlineDirectFunctionCall(*node, *func, ctx, optimized_graph));
          continue;
        } else {
          // Not inlinable: the node may still be specialized in 2c.
          VLOG(2) << inlinable.error_message();
          skip_nodes->insert(node->name());
        }
      }

      // 2b. Inline indirect function call if it's inlinable.
      if (inline_func && is_indirect_func) {
        Status inlinable = IsInlinableIndirectFunctionCall(ctx, *func, *node);
        if (inlinable.ok()) {
          TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
              InlineIndirectFunctionCall(*node, *func, &ctx, optimized_graph));
          continue;
        } else {
          // Not inlinable: the node may still be specialized in 2c.
          VLOG(2) << inlinable.error_message();
          skip_nodes->insert(node->name());
        }
      }

      // 2c. Specialize it to its instantiation context if can't be inlined,
      // and it has something worth specializing.
      bool specialization_worthy = IsParametrized(*func) ||
                                   HasTrulyConstInputs(*node, ctx) ||
                                   HasUnusedOutputs(*node, *func, ctx);

      // Do not specialize if function has custom gradient.
      const string grad_func = ctx.function_library().FindGradient(func_name);

      if (specialize_func && grad_func.empty() && specialization_worthy) {
        // TODO(ezhulenev): Specialize function call if input has a known shape.
        // Specialize function body for its instantiation attributes and inputs.
        TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
            SpecializeFunction(*node, *func, &ctx, optimized_graph));
        continue;
      } else {
        VLOG(2) << "Skip function specialization: " << func->signature().name();
        skip_nodes->insert(node->name());
      }
    }

    // ---------------------------------------------------------------------- //
    // If we reached this point, node was not handled by any of the stages
    // (inline, specialize), simply copy the node to the optimized graph.
    copy_node();

#undef TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED
  }

  // Apply tensor mappings and control overrides collected while inlining and
  // specializing function calls.
  TF_RETURN_IF_ERROR(RestoreGraphInvariants(ctx, optimized_graph));

  // Preserve the graph version.
  *optimized_graph->mutable_versions() = graph.versions();

  // Prune unreachable function from the library.
  if (options_.enable_trim_function_library) {
    *optimized_graph->mutable_library() =
        PruneFunctionLibrary(ctx.function_library(), *optimized_graph);
  } else {
    *optimized_graph->mutable_library() = ctx.function_library().ToProto();
  }

  // Before returning we check if after single optimization pass we have more
  // unoptimized function calls.
  *graph_has_unoptimized_function_calls = false;
  for (const NodeDef& node : optimized_graph->node()) {
    // Check if we can inline symbolic gradient.
    if (IsSymbolicGradient(node) && inline_gradients &&
        skip_nodes->count(node.name()) == 0) {
      *graph_has_unoptimized_function_calls = true;
      break;
    }

    // Check if after inlining we have unoptimized function calls.
    const FunctionDef* func = FindFunctionCall(ctx, node);
    if (func != nullptr && !MarkedSpecialized(*func) &&
        skip_nodes->count(node.name()) == 0) {
      *graph_has_unoptimized_function_calls = true;
      break;
    }
  }

  return Status::OK();
}
2186
Optimize(Cluster *,const GrapplerItem & item,GraphDef * optimized_graph)2187 Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item,
2188 GraphDef* optimized_graph) {
2189 // Nothing to do here.
2190 if (item.graph.library().function_size() == 0) {
2191 *optimized_graph = item.graph;
2192 return Status::OK();
2193 }
2194
2195 // Do not retry failed function inlining or specialization.
2196 std::unordered_set<string> skip_nodes;
2197 bool graph_has_unoptimized_function_calls = false;
2198
2199 // We'll keep running function optimizer pass until we inlined and optimized
2200 // all function call nodes.
2201 int iteration = 0;
2202 constexpr int kMaxIterations = 50;
2203
2204 // 1. Run first optimizer pass with GrapplerItem.graph.
2205 TF_RETURN_IF_ERROR(RunFunctionOptimizerPass(
2206 item, item.graph, 0, &skip_nodes, optimized_graph,
2207 &graph_has_unoptimized_function_calls));
2208
2209 // 2. If after function inlining we have unoptimized function calls, we have
2210 // to run function optimization pass one more time.
2211 while (graph_has_unoptimized_function_calls) {
2212 if (iteration++ > kMaxIterations) {
2213 VLOG(1) << "Break function optimizer loop at iteration #" << iteration;
2214 break;
2215 }
2216
2217 GraphDef workspace_graph;
2218 workspace_graph.Swap(optimized_graph);
2219
2220 TF_RETURN_IF_ERROR(RunFunctionOptimizerPass(
2221 item, workspace_graph, iteration, &skip_nodes, optimized_graph,
2222 &graph_has_unoptimized_function_calls));
2223 }
2224
2225 return Status::OK();
2226 }
2227
Feedback(Cluster * cluster,const GrapplerItem & item,const GraphDef & optimized_graph,double result)2228 void FunctionOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
2229 const GraphDef& optimized_graph,
2230 double result) {
2231 // Nothing to do for FunctionOptimizer.
2232 }
2233
2234 } // end namespace grappler
2235 } // end namespace tensorflow
2236