1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/function_optimizer.h"
17 
18 #include <vector>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/str_replace.h"
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/core/common_runtime/device.h"
27 #include "tensorflow/core/common_runtime/device_mgr.h"
28 #include "tensorflow/core/common_runtime/device_set.h"
29 #include "tensorflow/core/common_runtime/function.h"
30 #include "tensorflow/core/common_runtime/lower_if_while.h"
31 #include "tensorflow/core/common_runtime/optimization_registry.h"
32 #include "tensorflow/core/common_runtime/placer.h"
33 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
34 #include "tensorflow/core/framework/attr_value_util.h"
35 #include "tensorflow/core/framework/function.h"
36 #include "tensorflow/core/framework/function.pb.h"
37 #include "tensorflow/core/framework/graph_def_util.h"
38 #include "tensorflow/core/framework/node_def.pb.h"
39 #include "tensorflow/core/framework/node_def_util.h"
40 #include "tensorflow/core/framework/op_def.pb.h"
41 #include "tensorflow/core/framework/versions.pb.h"
42 #include "tensorflow/core/graph/graph_constructor.h"
43 #include "tensorflow/core/graph/tensor_id.h"
44 #include "tensorflow/core/grappler/graph_topology_view.h"
45 #include "tensorflow/core/grappler/grappler_item.h"
46 #include "tensorflow/core/grappler/mutable_graph_view.h"
47 #include "tensorflow/core/grappler/op_types.h"
48 #include "tensorflow/core/grappler/utils.h"
49 #include "tensorflow/core/grappler/utils/functions.h"
50 #include "tensorflow/core/grappler/utils/topological_sort.h"
51 #include "tensorflow/core/grappler/utils/traversal.h"
52 #include "tensorflow/core/lib/gtl/map_util.h"
53 
54 namespace tensorflow {
55 namespace grappler {
56 namespace {
57 
58 // WARNING: Code in this file implicitly assumes that function input and output
59 // arguments are plain tensors (tensor lists are not supported). Function inputs
60 // and outputs are always expanded to a single placeholder or output tensor.
61 // With this assumption, the calling node's input/output ports always match
62 // function input/output arguments.
63 //
64 // This is guaranteed by the implementation of MakeGrapplerFunctionItem.
65 
// Mark functions that were created as a result of function specialization.
constexpr char kGrapplerSpecializedFuncAttr[] = "_GrapplerSpecializedFunc";

// Name of the attribute that defines the function for indirect function calls.
constexpr char kFuncAttrName[] = "f";

// Name of the boolean function attribute inspected by MarkedNoInline().
constexpr char kNoInlineAttr[] = "_noinline";

// Name of the node that will have control edges from function input nodes, and
// also used as a new destination for incoming control edges.
constexpr char kInputsReadyNodeName[] = "inputs_ready";

// Name of the node that will have control edges from function control output
// nodes, and also used as a new source of outgoing control edges. This node
// will guarantee that all side-effects inside function body will be executed
// after function inlining.
constexpr char kSideEffectsExecutedNodeName[] = "side_effects_executed";
83 
AttrIsTrue(const FunctionDef & func,const string & attr)84 bool AttrIsTrue(const FunctionDef& func, const string& attr) {
85   return func.attr().count(attr) != 0 && func.attr().at(attr).b();
86 }
87 
MarkedSpecialized(const FunctionDef & func)88 bool MarkedSpecialized(const FunctionDef& func) {
89   return AttrIsTrue(func, kGrapplerSpecializedFuncAttr);
90 }
91 
MarkedNoInline(const FunctionDef & func)92 bool MarkedNoInline(const FunctionDef& func) {
93   return AttrIsTrue(func, kNoInlineAttr);
94 }
95 
96 // There are two ways of calling a Tensorflow function:
97 //
98 // 1. Direct function call: node.op() is the name of the function.
99 //
100 // 2. Indirect function call: the function name is passed through a node
101 //    attribute, and special Tensorflow kernels are responsible for calling the
102 //    function through the FunctionLibraryRuntime. Example: PartitionedCallOp.
103 
104 // Check if func_node.op() matches the name in FunctionDef signature.
IsDirectFunctionCall(const FunctionDef & func,const NodeDef & func_node)105 bool IsDirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
106   return func_node.op() == func.signature().name();
107 }
108 
109 // Check if func_node has function attribute with a function name matching
110 // FunctionDef signature.
IsIndirectFunctionCall(const FunctionDef & func,const NodeDef & func_node)111 bool IsIndirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
112   auto* func_attr = AttrSlice(func_node).Find(kFuncAttrName);
113   return func_attr != nullptr && func_attr->has_func() &&
114          func_attr->func().name() == func.signature().name();
115 }
116 
FunctionInstantiationAttributes(const FunctionDef & func,const NodeDef & func_node)117 AttrSlice FunctionInstantiationAttributes(const FunctionDef& func,
118                                           const NodeDef& func_node) {
119   if (IsDirectFunctionCall(func, func_node)) {
120     return AttrSlice(func_node);
121 
122   } else if (IsIndirectFunctionCall(func, func_node)) {
123     auto* func_attr = AttrSlice(func_node).Find(kFuncAttrName);
124     return AttrSlice(&func_attr->func().attr());
125 
126   } else {
127     LOG(WARNING) << "Can't resolve function instantiation attributes: "
128                  << SummarizeNodeDef(func_node);
129     return AttrSlice();
130   }
131 }
132 
133 // This is a fake device that should not be used for any op kernel execution,
134 // the only purpose of this device is to be passed as a part of DeviceSet to the
135 // Placer.
136 class FakeDevice : public Device {
137  public:
FakeDevice(Env * env,const string & device)138   FakeDevice(Env* env, const string& device) : Device(env, attr(device)) {}
FakeDevice(const string & device)139   explicit FakeDevice(const string& device) : FakeDevice(nullptr, device) {}
Sync()140   Status Sync() override { return Status::OK(); }
141 
142  private:
attr(const string & device)143   static DeviceAttributes attr(const string& device) {
144     DeviceNameUtils::ParsedName parsed_name;
145     bool parsed = DeviceNameUtils::ParseFullName(device, &parsed_name);
146     DCHECK(parsed) << "Failed to parse full device name: " << device;
147 
148     DeviceAttributes attr;
149     attr.set_name(device);
150     attr.set_device_type(parsed_name.type);
151     return attr;
152   }
153 };
154 
155 // -------------------------------------------------------------------------- //
156 // Function specialization.
157 //
158 // FunctionDef is somewhat similar to function template in C++, given all the
159 // type parameters (and attribute values) it generates a statically defined
160 // graph from the type parametrized "graph template" (function body).
161 //
162 // Function specialization instantiates a parametrized FunctionDef into a
163 // statically defined graph, and then converts it back to the fully defined
164 // FunctionDef (it doesn't have any unknown type parameters or attribute
165 // values, known as placeholders).
166 //
167 // Given the fully specified graph we can apply all the Grappler optimizers to
168 // it (see details in MetaOptimizer). Also we can push known constant inputs
169 // into the function body, and remove unused outputs/inputs.
170 
171 // Specialized function instantiation type parameters, body parameters, and
172 // const inputs.
173 struct FunctionSpecializationSignature {
174   // Currently we do not support functions with tensor lists as inputs or
175   // outputs, so caller node input/output ports always match function
176   // input/output arguments.
177   using InputPort = int;
178   using OutputPort = int;
179 
180   string func_name;
181   bool is_in_fetch_set;
182   absl::flat_hash_set<OutputPort> active_outputs;
183   absl::flat_hash_map<string, DataType> type_parameters;
184   absl::flat_hash_map<string, AttrValue> body_parameters;
185   absl::flat_hash_map<InputPort, string> const_inputs;
186 
operator ==tensorflow::grappler::__anondd39c0ba0111::FunctionSpecializationSignature187   bool operator==(const FunctionSpecializationSignature& other) const {
188     bool equals = func_name == other.func_name &&
189                   is_in_fetch_set == other.is_in_fetch_set &&
190                   active_outputs == other.active_outputs &&
191                   type_parameters == other.type_parameters &&
192                   const_inputs == other.const_inputs;
193 
194     if (!equals) return false;
195 
196     // Equality is not defined for AttrValue.
197     if (body_parameters.size() != other.body_parameters.size()) return false;
198 
199     for (const auto& lhs : body_parameters) {
200       auto it = other.body_parameters.find(lhs.first);
201       if (it == other.body_parameters.end()) return false;
202       if (!FastAreAttrValuesEqual(lhs.second, (*it).second)) return false;
203     }
204 
205     return true;
206   }
207 
208   template <typename H>
AbslHashValue(H h,const FunctionSpecializationSignature & s)209   friend H AbslHashValue(H h, const FunctionSpecializationSignature& s) {
210     H base = H::combine(std::move(h), s.func_name, s.is_in_fetch_set);
211 
212     // First pre-compute hashes for all values in collections with
213     // non-deterministic iteration order.
214     std::vector<uint64> hashes;
215     hashes.reserve(s.active_outputs.size()         //
216                    + s.type_parameters.size() * 2  //
217                    + s.body_parameters.size() * 2  //
218                    + s.const_inputs.size() * 2);
219 
220     absl::c_transform(s.active_outputs, std::back_inserter(hashes),
221                       hash<OutputPort>());
222 
223     using TypeParam = std::pair<const string, DataType>;
224     absl::c_for_each(s.type_parameters, [&hashes](const TypeParam& type_param) {
225       AttrValue attr_value;
226       attr_value.set_type(type_param.second);
227       hashes.push_back(Hash64(type_param.first));
228       hashes.push_back(AttrValueHash(attr_value));
229     });
230 
231     using BodyParam = std::pair<const string, AttrValue>;
232     absl::c_for_each(s.body_parameters, [&hashes](const BodyParam& body_param) {
233       hashes.push_back(Hash64(body_param.first));
234       hashes.push_back(FastAttrValueHash(body_param.second));
235     });
236 
237     using ConstInput = std::pair<const InputPort, string>;
238     absl::c_for_each(s.const_inputs, [&hashes](const ConstInput& const_input) {
239       hashes.push_back(hash<InputPort>()(const_input.first));
240       hashes.push_back(Hash64(const_input.second));
241     });
242 
243     // Combine all pre-computed hashes in a deterministic order.
244     absl::c_sort(hashes);
245     return H::combine_contiguous(std::move(base), hashes.data(), hashes.size());
246   }
247 };
248 
// Result of specializing a function for one call-site signature. Instances are
// stored in FunctionOptimizerContext keyed by FunctionSpecializationSignature.
struct FunctionSpecialization {
  // Name of the specialized function that the call site should call instead
  // of the original function.
  string specialized_func_name;
  // True if the function caller node is in GrapplerItem fetch set.
  bool is_in_fetch_set;
  // Names of the tensors that were pushed down into the function body.
  absl::flat_hash_set<string> const_inputs;
  // Control dependencies of pushed down const inputs have to be attached to
  // function caller node.
  absl::flat_hash_set<string> control_deps;
  // Output tensors (ports) that are consumed by other nodes in the graph or
  // are in a GrapplerItem fetch set.
  absl::flat_hash_set<int> active_outputs;
  // Mapping from original function output port (first) to the output port of
  // the specialized function (second). If function specialization changes the
  // number of function outputs it's required to update all node consumers.
  std::vector<std::pair<int, int>> output_mapping;
};
266 
267 // Function optimizer context initialized once for each optimization pass, and
268 // it uses the latest available graph (for the first iteration it will be the
269 // GrapplerItem.graph, for next iterations it will be the output of previous
270 // function optimizer pass).
271 class FunctionOptimizerContext {
272  public:
FunctionOptimizerContext(const GrapplerItem & item,RewriterConfig::Toggle opt_level,const GraphDef & graph)273   explicit FunctionOptimizerContext(const GrapplerItem& item,
274                                     RewriterConfig::Toggle opt_level,
275                                     const GraphDef& graph)
276       : item_(&item),
277         opt_level_(opt_level),
278         function_library_(OpRegistry::Global(), graph.library()),
279         truly_const_nodes_(InferTrulyConstNodes(item, graph)),
280         graph_view_(&graph) {}
281 
item() const282   const GrapplerItem& item() const { return *item_; }
283 
graph_version() const284   const int graph_version() const { return item_->graph.versions().producer(); }
285 
opt_level() const286   RewriterConfig::Toggle opt_level() const { return opt_level_; }
287 
function_library() const288   const FunctionLibraryDefinition& function_library() const {
289     return function_library_;
290   }
291 
mutable_function_library()292   FunctionLibraryDefinition* mutable_function_library() {
293     return &function_library_;
294   }
295 
mutable_function_library_runtime()296   FunctionLibraryRuntime* mutable_function_library_runtime() {
297     InitializeFunctionLibraryRuntime();
298     return flr_;
299   }
300 
301   const absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>&
tensor_mapping() const302   tensor_mapping() const {
303     return tensor_mapping_;
304   }
305 
control_overrides() const306   const absl::flat_hash_map<string, std::vector<string>>& control_overrides()
307       const {
308     return control_overrides_;
309   }
310 
graph_view() const311   const GraphView& graph_view() const { return graph_view_; }
312 
devices() const313   const DeviceSet* devices() const {
314     // Create fake devices lazily only if we need a DeviceSet.
315     if (available_devices_.empty() && !item_->devices().empty()) {
316       for (const string& name : item_->devices()) {
317         auto device = absl::make_unique<FakeDevice>(name);
318         available_device_set_.AddDevice(device.get());
319         available_devices_.push_back(std::move(device));
320       }
321     }
322     return &available_device_set_;
323   }
324 
IsFetchNode(const string & node_name) const325   bool IsFetchNode(const string& node_name) const {
326     return absl::c_any_of(item_->fetch, [&](const string& fetch) {
327       return ParseTensorName(fetch).node() == node_name;
328     });
329   }
330 
IsKeepOp(const string & node_name) const331   bool IsKeepOp(const string& node_name) const {
332     return absl::c_any_of(item_->keep_ops, [&](const string& keep_node) {
333       return keep_node == node_name;
334     });
335   }
336 
IsTrulyConst(const string & name) const337   bool IsTrulyConst(const string& name) const {
338     return TrulyConstNode(name) != nullptr;
339   }
340 
TrulyConstNode(const string & name) const341   const NodeDef* TrulyConstNode(const string& name) const {
342     return gtl::FindWithDefault(truly_const_nodes_, name, nullptr);
343   }
344 
FindFunctionSpecialization(const FunctionSpecializationSignature & sig) const345   const FunctionSpecialization* FindFunctionSpecialization(
346       const FunctionSpecializationSignature& sig) const {
347     return gtl::FindOrNull(specialized_functions_, sig);
348   }
349 
AddSpecializedFunction(const FunctionSpecializationSignature & sig,const FunctionSpecialization & specialized_func)350   void AddSpecializedFunction(const FunctionSpecializationSignature& sig,
351                               const FunctionSpecialization& specialized_func) {
352     specialized_functions_.emplace(sig, specialized_func);
353   }
354 
AddTensorMapping(const SafeTensorId & from,const SafeTensorId & to)355   void AddTensorMapping(const SafeTensorId& from, const SafeTensorId& to) {
356     DCHECK(from.index() != Graph::kControlSlot)
357         << "Tensor mapping must be from regular tensor";
358     DCHECK(to.index() != Graph::kControlSlot)
359         << "Tensor mapping must be to regular tensor";
360 
361     auto inserted = tensor_mapping_.insert({from, to});
362     DCHECK(inserted.second)
363         << "Failed to insert duplicated tensor mapping: "
364         << "from=" << from.ToString() << " to=" << to.ToString();
365   }
366 
AddTensorMapping(const string & func_node,const FunctionSpecialization & specialized_func)367   void AddTensorMapping(const string& func_node,
368                         const FunctionSpecialization& specialized_func) {
369     for (const auto& pair : specialized_func.output_mapping) {
370       int from_idx = pair.first;
371       int to_idx = pair.second;
372       if (from_idx != to_idx) {
373         SafeTensorId from_tensor(func_node, from_idx);
374         SafeTensorId to_tensor(func_node, to_idx);
375         AddTensorMapping(from_tensor, to_tensor);
376       }
377     }
378   }
379 
AddControlOverrides(const NodeDef & func_node,const std::vector<string> & control_overrides)380   void AddControlOverrides(const NodeDef& func_node,
381                            const std::vector<string>& control_overrides) {
382     VLOG(4) << "Add control overrides: from=" << func_node.name() << " to: ["
383             << absl::StrJoin(control_overrides, ", ") << "]";
384 
385     control_overrides_[func_node.name()].reserve(control_overrides.size());
386     for (const string& control_override : control_overrides) {
387       control_overrides_[func_node.name()].push_back(control_override);
388     }
389   }
390 
391  private:
InferTrulyConstNodes(const GrapplerItem & item,const GraphDef & graph)392   static absl::flat_hash_map<string, const NodeDef*> InferTrulyConstNodes(
393       const GrapplerItem& item, const GraphDef& graph) {
394     absl::flat_hash_set<absl::string_view> feed_nodes;
395     for (const auto& feed : item.feed) {
396       feed_nodes.insert(feed.first);
397     }
398 
399     absl::flat_hash_map<string, const NodeDef*> const_nodes;
400     for (const NodeDef& node : graph.node()) {
401       if (IsConstant(node) && !feed_nodes.contains(node.name())) {
402         const_nodes[node.name()] = &node;
403       }
404     }
405 
406     return const_nodes;
407   }
408 
InitializeFunctionLibraryRuntime()409   void InitializeFunctionLibraryRuntime() {
410     if (!flr_) {
411       Env* env = Env::Default();
412       std::vector<std::unique_ptr<Device>> devices;
413       devices.push_back(absl::make_unique<FakeDevice>(env, "/device:CPU:0"));
414       device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
415       OptimizerOptions optimizer_opts;
416       optimizer_opts.set_do_function_inlining(true);
417       process_flr_.reset(new ProcessFunctionLibraryRuntime(
418           device_mgr_.get(), env, item_->graph.versions().producer(),
419           &function_library_, optimizer_opts));
420       flr_ = process_flr_->GetFLR(device_mgr_->ListDevices()[0]->name());
421     }
422   }
423 
424   const GrapplerItem* item_;  // must outlive this object
425   RewriterConfig::Toggle opt_level_;
426 
427   // Function library constructed from current graph.
428   FunctionLibraryDefinition function_library_;
429 
430   // These fields initialized lazily only if needed.
431   std::unique_ptr<DeviceMgr> device_mgr_;
432   std::unique_ptr<ProcessFunctionLibraryRuntime> process_flr_;
433   FunctionLibraryRuntime* flr_ = nullptr;
434 
435   // List of available `FakedDevices` (lazily initialized, see devices()).
436   mutable std::vector<std::unique_ptr<Device>> available_devices_;
437 
438   // DeviceSet of fake devices (`FakeDevice`) constructed from
439   // item_.devices() (lazily initialized).
440   mutable DeviceSet available_device_set_;
441 
442   // Nodes that are Const and not in feed.
443   absl::flat_hash_map<string, const NodeDef*> truly_const_nodes_;
444   // Specialized functions.
445   absl::flat_hash_map<FunctionSpecializationSignature,
446                       const FunctionSpecialization>
447       specialized_functions_;
448 
449   // After function inlining and specialization, the optimized graph might be in
450   // invalid state, nodes can read from non-existing function call nodes that
451   // were inlined, or they can read from output index that is no longer valid
452   // after unused outputs pruning.
453   //
454   // Tensor mapping that has to be applied to the graph after all functions
455   // optimizations (invalidated tensor id -> optimized graph tensor id).
456   absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>
457       tensor_mapping_;
458 
459   // When we inline a function into the optimized graph, we no longer have the
460   // function call node to anchor control dependencies. Instead we must expand
461   // each function call control output edge into multiple control dependencies
462   // to all side-effectful ops inside the function body.
463   //
464   // Invalidated function call node name -> Inlined side-effectful nodes
465   absl::flat_hash_map<string, std::vector<string>> control_overrides_;
466 
467   // Use graph view to find active outputs of the function caller nodes.
468   GraphView graph_view_;
469 
470   TF_DISALLOW_COPY_AND_ASSIGN(FunctionOptimizerContext);
471 };
472 
473 // Returns a pointer to the called function definition iff the given node is
474 // indeed a function call. Otherwise returns nullptr.
FindFunctionCall(const FunctionOptimizerContext & ctx,const NodeDef & node)475 const FunctionDef* FindFunctionCall(const FunctionOptimizerContext& ctx,
476                                     const NodeDef& node) {
477   // Check if a node does indirect function call via PartitionedCallOp.
478   if (IsPartitionedCall(node) || IsStatefulPartitionedCall(node)) {
479     const AttrValue* func_attr = AttrSlice(node).Find("f");
480     return (func_attr != nullptr && func_attr->has_func())
481                ? ctx.function_library().Find(func_attr->func().name())
482                : nullptr;
483   }
484 
485   // Check if the function op itself is a function name.
486   return ctx.function_library().Find(node.op());
487 }
488 
GetActiveOutputs(const NodeDef & node,const FunctionOptimizerContext & ctx,int size_hint=0)489 absl::flat_hash_set<int> GetActiveOutputs(const NodeDef& node,
490                                           const FunctionOptimizerContext& ctx,
491                                           int size_hint = 0) {
492   absl::flat_hash_set<int> active_outputs;
493   active_outputs.reserve(static_cast<size_t>(size_hint));
494 
495   // 1. Output can be consumed by the other graph node.
496   const auto node_fanout_edges =
497       ctx.graph_view().GetFanoutEdges(node, /*include_controlled_edges=*/false);
498   for (const GraphView::Edge& edge : node_fanout_edges) {
499     active_outputs.insert(edge.src.port_id);
500   }
501 
502   // 2. Or it can be in a fetch set.
503   for (const string& fetch : ctx.item().fetch) {
504     TensorId fetch_tensor = ParseTensorName(fetch);
505     if (fetch_tensor.node() == node.name()) {
506       active_outputs.insert(fetch_tensor.index());
507     }
508   }
509 
510   return active_outputs;
511 }
512 
HasTrulyConstInputs(const NodeDef & node,const FunctionOptimizerContext & ctx)513 bool HasTrulyConstInputs(const NodeDef& node,
514                          const FunctionOptimizerContext& ctx) {
515   const auto is_truly_const = [&ctx](const string& input) {
516     return ctx.IsTrulyConst(NodeName(input));
517   };
518   return absl::c_any_of(node.input(), is_truly_const);
519 }
520 
HasUnusedOutputs(const NodeDef & func_node,const FunctionDef & func,const FunctionOptimizerContext & ctx)521 bool HasUnusedOutputs(const NodeDef& func_node, const FunctionDef& func,
522                       const FunctionOptimizerContext& ctx) {
523   // Functions with tensor list outputs are not supported right now, so the
524   // number of output args is the same as number of possible function caller
525   // node outputs.
526   int num_outputs = func.signature().output_arg_size();
527   const absl::flat_hash_set<int> active_outputs =
528       GetActiveOutputs(func_node, ctx, /*size_hind*/ num_outputs);
529 
530   return active_outputs.size() != num_outputs;
531 }
532 
533 // Return pruned FunctionDefLibrary with functions that are reachable from
534 // the optimized graph.
PruneFunctionLibrary(const FunctionLibraryDefinition & flib,const GraphDef & optimized_graph)535 FunctionDefLibrary PruneFunctionLibrary(const FunctionLibraryDefinition& flib,
536                                         const GraphDef& optimized_graph) {
537   FunctionLibraryDefinition pruned_flib =
538       flib.ReachableDefinitions(optimized_graph);
539 
540   int pruned_functions = static_cast<int>(pruned_flib.num_functions()) -
541                          static_cast<int>(flib.num_functions());
542 
543   VLOG(3) << "Pruned function library: " << pruned_flib.num_functions()
544           << " functions (" << pruned_functions << ")";
545 
546   return pruned_flib.ToProto();
547 }
548 
549 // Push all constant inputs of an instantiating node into the function body.
PushDownConstInputs(const NodeDef & func_node,const FunctionOptimizerContext & ctx,GrapplerFunctionItem * item,absl::flat_hash_set<string> * const_inputs,absl::flat_hash_set<string> * control_deps)550 Status PushDownConstInputs(const NodeDef& func_node,
551                            const FunctionOptimizerContext& ctx,
552                            GrapplerFunctionItem* item,
553                            absl::flat_hash_set<string>* const_inputs,
554                            absl::flat_hash_set<string>* control_deps) {
555   // Record node control dependencies in the control_deps set.
556   const auto record_control_deps = [&](const NodeDef* const_input) {
557     for (int i = const_input->input_size() - 1; i >= 0; --i) {
558       const string& input = const_input->input(i);
559       if (IsControlInput(input))
560         control_deps->insert(input);
561       else
562         break;
563     }
564   };
565 
566   for (int i = func_node.input_size() - 1; i >= 0; --i) {
567     const string& input = func_node.input(i);
568     if (IsControlInput(input)) continue;
569 
570     const string node_name = NodeName(input);
571     if (ctx.IsTrulyConst(node_name)) {
572       VLOG(3) << "Push const into function body: input=" << input;
573       const auto* const_input = CHECK_NOTNULL(ctx.TrulyConstNode(node_name));
574       const_inputs->insert(input);
575       record_control_deps(const_input);
576       TF_RETURN_IF_ERROR(ReplaceInputWithConst(*const_input, i, item));
577     }
578   }
579 
580   return Status::OK();
581 }
582 
583 // Remove inputs that were pushed into the function body, and attach their
584 // control dependencies to the function caller node.
RemovePushedDownConstInputs(const FunctionSpecialization & specialization,NodeDef * specialized_func_node)585 void RemovePushedDownConstInputs(const FunctionSpecialization& specialization,
586                                  NodeDef* specialized_func_node) {
587   // Nothing to do if it was no const inputs to the function node.
588   if (specialization.const_inputs.empty()) return;
589 
590   // Keep only non-const inputs.
591   std::vector<string> keep_inputs;
592   const auto& inputs = specialized_func_node->input();
593   std::copy_if(inputs.begin(), inputs.end(), std::back_inserter(keep_inputs),
594                [&](const string& input) {
595                  return specialization.const_inputs.find(input) ==
596                         specialization.const_inputs.end();
597                });
598 
599   specialized_func_node->clear_input();
600   for (const auto& keep : keep_inputs) specialized_func_node->add_input(keep);
601 
602   // Attach control dependencies of pushed down const input to the caller node.
603   if (!specialization.control_deps.empty()) {
604     absl::flat_hash_set<string> existing_control_deps;
605 
606     for (const string& input : keep_inputs) {
607       existing_control_deps.insert(AsControlDependency(NodeName(input)));
608     }
609 
610     for (const string& ctrl : specialization.control_deps) {
611       if (existing_control_deps.find(ctrl) == existing_control_deps.end()) {
612         VLOG(3) << "Forward control dependency: input=" << ctrl;
613         specialized_func_node->add_input(ctrl);
614       }
615     }
616   }
617 }
618 
619 // Remove Tin type parameters for pushed down const inputs.
RemovePushedDownConstInputTypes(const FunctionSpecialization & specialization,const NodeDef & func_node,NodeDef * specialized_func_node)620 void RemovePushedDownConstInputTypes(
621     const FunctionSpecialization& specialization, const NodeDef& func_node,
622     NodeDef* specialized_func_node) {
623   // Nothing to do if it was no const inputs to the function node.
624   if (specialization.const_inputs.empty()) return;
625 
626   // Make sure that original function caller has Tin attribute.
627   const AttrValue* tin = AttrSlice(func_node).Find("Tin");
628   if (tin == nullptr || !tin->has_list()) return;
629 
630   // Clear input types for the specialized node.
631   auto* attr = specialized_func_node->mutable_attr();
632   (*attr)["Tin"].mutable_list()->clear_type();
633 
634   // Keep types of non-const inputs.
635   for (int i = 0; i < func_node.input_size(); ++i) {
636     const string& input = func_node.input(i);
637     if (IsControlInput(input)) break;
638 
639     if (specialization.const_inputs.find(input) ==
640         specialization.const_inputs.end()) {
641       DataType dt = tin->list().type(i);
642       (*attr)["Tin"].mutable_list()->add_type(dt);
643     }
644   }
645 }
646 
647 // Remove Tout type parameters for pruned function outputs.
RemoveUnusedOutputsTypes(const FunctionSpecialization & specialization,const NodeDef & func_node,NodeDef * specialized_func_node)648 void RemoveUnusedOutputsTypes(const FunctionSpecialization& specialization,
649                               const NodeDef& func_node,
650                               NodeDef* specialized_func_node) {
651   // Make sure that original function caller has Tout attribute.
652   const AttrValue* tout = AttrSlice(func_node).Find("Tout");
653   if (tout == nullptr || !tout->has_list()) return;
654 
655   // Nothing to do if all outputs are active.
656   if (specialization.active_outputs.size() == tout->list().type_size()) return;
657 
658   // Clear input types for the specialized node.
659   auto* attr = specialized_func_node->mutable_attr();
660   (*attr)["Tout"].mutable_list()->clear_type();
661 
662   // Keep output types of active outputs only.
663   for (int i = 0; i < tout->list().type_size(); ++i) {
664     if (specialization.active_outputs.find(i) !=
665         specialization.active_outputs.end()) {
666       DataType dt = tout->list().type(i);
667       (*attr)["Tout"].mutable_list()->add_type(dt);
668     }
669   }
670 }
671 
UpdateSpecializedFunctionCallSite(const FunctionDef & func,const NodeDef & func_node,const string & specialized_func_name,NodeDef * specialized_func_node)672 Status UpdateSpecializedFunctionCallSite(const FunctionDef& func,
673                                          const NodeDef& func_node,
674                                          const string& specialized_func_name,
675                                          NodeDef* specialized_func_node) {
676   if (IsDirectFunctionCall(func, func_node)) {
677     specialized_func_node->set_op(specialized_func_name);
678 
679   } else if (IsIndirectFunctionCall(func, func_node)) {
680     auto* attr = specialized_func_node->mutable_attr();
681     (*attr)[kFuncAttrName].mutable_func()->set_name(specialized_func_name);
682 
683   } else {
684     return errors::InvalidArgument("Unknown function call site");
685   }
686 
687   return Status::OK();
688 }
689 
690 // Update a graph node created from the original function caller node, to the
691 // function specialization. Function specialization might change the number of
692 // inputs and outputs, so we have to make sure that graph node is updated
693 // accordingly.
UpdateSpecializedFunctionNode(const FunctionDef & func,const NodeDef & func_node,const FunctionSpecialization & specialization,NodeDef * specialized_func_node)694 Status UpdateSpecializedFunctionNode(
695     const FunctionDef& func, const NodeDef& func_node,
696     const FunctionSpecialization& specialization,
697     NodeDef* specialized_func_node) {
698   // Function called indirectly via custom kernel (e.g. PartitionedCallOp).
699   bool is_indirect_call = IsIndirectFunctionCall(func, func_node);
700 
701   // 1. Call the specialized function instead of original one.
702   TF_RETURN_IF_ERROR(UpdateSpecializedFunctionCallSite(
703       func, func_node, specialization.specialized_func_name,
704       specialized_func_node));
705 
706   // 2. Remove inputs corresponding to the pushed down consts.
707   RemovePushedDownConstInputs(specialization, specialized_func_node);
708 
709   // NOTE: PartitionedCallOp has `Tin` and `Tout` attributes for input/output
710   // types, that must be in sync with updated function signature.
711 
712   // 3. Update input types for the indirect function calls.
713   if (is_indirect_call) {
714     RemovePushedDownConstInputTypes(specialization, func_node,
715                                     specialized_func_node);
716   }
717 
718   // 4. Update output types for the indirect function call. It's unsafe to
719   // change the number of outputs for the fetch nodes, so we just skip them.
720   if (is_indirect_call && !specialization.is_in_fetch_set) {
721     RemoveUnusedOutputsTypes(specialization, func_node, specialized_func_node);
722   }
723 
724   // 5. Remove custom gradient annotation.
725   specialized_func_node->mutable_attr()->erase("_gradient_op_type");
726 
727   return Status::OK();
728 }
729 
InitializeFunctionSpecializationSignature(const NodeDef & func_node,const FunctionDef & func,const AttrSlice & func_instantiation_attr,const FunctionOptimizerContext & ctx,FunctionSpecializationSignature * sig)730 Status InitializeFunctionSpecializationSignature(
731     const NodeDef& func_node, const FunctionDef& func,
732     const AttrSlice& func_instantiation_attr,
733     const FunctionOptimizerContext& ctx, FunctionSpecializationSignature* sig) {
734   DCHECK(sig->const_inputs.empty());
735   DCHECK(sig->active_outputs.empty());
736 
737   sig->func_name = func.signature().name();
738   sig->is_in_fetch_set = ctx.IsFetchNode(func_node.name());
739   sig->active_outputs = GetActiveOutputs(func_node, ctx);
740 
741   TF_RETURN_IF_ERROR(InstantiationTypeParameters(func, func_instantiation_attr,
742                                                  &sig->type_parameters));
743   TF_RETURN_IF_ERROR(InstantiationBodyParameters(func, func_instantiation_attr,
744                                                  &sig->body_parameters));
745 
746   for (int i = 0; i < func_node.input_size(); ++i) {
747     const string& input = func_node.input(i);
748     if (IsControlInput(input)) break;
749     if (ctx.IsTrulyConst(input)) {
750       sig->const_inputs.emplace(i, input);
751     }
752   }
753 
754   return Status::OK();
755 }
756 
757 // Create a name for the function specialization. The name of the function, name
758 // of the node instantiating it, and a Grappler item id should generate unique
759 // function name. Meta optimizer might create multiple Grappler items for the
760 // same graph when optimizing functions, but it's guaranteed that they all will
761 // have unique ids.
SpecializedFunctionName(const FunctionOptimizerContext & ctx,const FunctionDef & func,const NodeDef & func_node)762 string SpecializedFunctionName(const FunctionOptimizerContext& ctx,
763                                const FunctionDef& func,
764                                const NodeDef& func_node) {
765   return absl::Substitute(
766       "$0_specialized_for_$1_at_$2", func.signature().name(),
767       absl::StrReplaceAll(func_node.name(), {{"/", "_"}}), ctx.item().id);
768 }
769 
SpecializeFunction(const NodeDef & func_node,const FunctionDef & func,FunctionOptimizerContext * ctx,GraphDef * optimized_graph)770 Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
771                           FunctionOptimizerContext* ctx,
772                           GraphDef* optimized_graph) {
773   VLOG(2) << "Specialize function call: " << SummarizeNodeDef(func_node);
774 
775   const AttrSlice func_instantiation_attr =
776       FunctionInstantiationAttributes(func, func_node);
777 
778   FunctionSpecializationSignature signature;
779   TF_RETURN_IF_ERROR(InitializeFunctionSpecializationSignature(
780       func_node, func, func_instantiation_attr, *ctx, &signature));
781 
782   // Check if function was already specialized for identical context.
783   const FunctionSpecialization* already_specialized =
784       ctx->FindFunctionSpecialization(signature);
785 
786   if (already_specialized) {
787     VLOG(2) << "Function was already specialized in identical context: "
788                "specialized_name="
789             << already_specialized->specialized_func_name;
790 
791     // Add a function call node for the specialized function.
792     NodeDef* specialized_func_node = optimized_graph->add_node();
793     *specialized_func_node = func_node;
794 
795     TF_RETURN_IF_ERROR(UpdateSpecializedFunctionNode(
796         func, func_node, *already_specialized, specialized_func_node));
797 
798     ctx->AddTensorMapping(specialized_func_node->name(), *already_specialized);
799 
800     return Status::OK();
801   }
802 
803   // Add a new specialized function definition to the library.
804   const auto& flib = ctx->function_library();
805 
806   // Make a GrapplerFunctionItem and convert it back to FunctionDef after
807   // pushing all constant inputs into the function body.
808   GrapplerFunctionItem item;
809   TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
810       func, func_instantiation_attr, flib, ctx->graph_version(), &item));
811 
812   // Push const inputs into the function body, and keep track of their control
813   // dependencies.
814   absl::flat_hash_set<string> const_inputs;
815   absl::flat_hash_set<string> control_deps;
816   TF_RETURN_IF_ERROR(PushDownConstInputs(func_node, *ctx, &item, &const_inputs,
817                                          &control_deps));
818 
819   // Remove function outputs that do not have any consumers. We can't safely
820   // update outputs for the fetch nodes, so we just skip them.
821   std::vector<std::pair<int, int>> output_mapping;
822   if (!signature.is_in_fetch_set) {
823     int num_func_outputs = 0;
824     for (const auto& out_arg : item.outputs()) {
825       num_func_outputs += out_arg.output_nodes.size();
826     }
827 
828     absl::flat_hash_set<int> remove;
829     for (int i = 0; i < num_func_outputs; ++i) {
830       if (!signature.active_outputs.count(i)) remove.insert(i);
831     }
832 
833     TF_RETURN_IF_ERROR(RemoveFunctionOutputs(remove, &item, &output_mapping));
834   }
835 
836   // TODO(ezhulenev): Push down known input shapes.
837   FunctionDef specialized_func;
838   TF_RETURN_IF_ERROR(MakeFunctionDef(item, flib, &specialized_func));
839 
840   // Find a name for specialized function.
841   const string specialized_func_name =
842       SpecializedFunctionName(*ctx, func, func_node);
843   if (flib.Contains(specialized_func_name)) {
844     // NOTE(ezhulenev): This should never happen. If it happens, it's a sign of
845     // a serious internal error, that must be investigated.
846     return errors::Internal("Created duplicate function specialization");
847   }
848 
849   specialized_func.mutable_signature()->set_name(specialized_func_name);
850   auto* specialized_attr = specialized_func.mutable_attr();
851   (*specialized_attr)[kGrapplerSpecializedFuncAttr].set_b(true);
852 
853   // Add specialized function to the library.
854   TF_RETURN_IF_ERROR(
855       ctx->mutable_function_library()->AddFunctionDef(specialized_func));
856 
857   // Add a function call node for the specialized function.
858   NodeDef* specialized_func_node = optimized_graph->add_node();
859   *specialized_func_node = func_node;
860 
861   FunctionSpecialization func_specialization = {
862       specialized_func_name, signature.is_in_fetch_set, const_inputs,
863       control_deps,          signature.active_outputs,  output_mapping};
864 
865   TF_RETURN_IF_ERROR(UpdateSpecializedFunctionNode(
866       func, func_node, func_specialization, specialized_func_node));
867 
868   ctx->AddSpecializedFunction(signature, func_specialization);
869   ctx->AddTensorMapping(specialized_func_node->name(), func_specialization);
870 
871   return Status::OK();
872 }
873 
874 // -------------------------------------------------------------------------- //
// Inline direct function calls.
876 //
877 // When we inline direct function calls, we instantiate the function body from
878 // its FunctionDef and caller node attributes, and embed the instantiated graph
879 // into the "main graph". When we do that, we must preserve the function call
880 // semantics:
881 //
882 // 1) All input nodes must be executed before any of function body nodes will
883 //    start executing.
884 // 2) All function body nodes must be executed before any of the nodes, reading
885 //    outputs of the function will start executing.
886 // 3) All nodes with side effects inside a function must be executed, this is
887 //    different from the nodes with side effects in the main graph, that can be
888 //    pruned if they are not in transitive dependency set of any of the fetch
889 //    nodes.
// 4) All nodes of the function body must be executed on the device specified
//    by the function caller node.
892 //
893 // To guarantee that function call semantics are preserved after inlining, we
894 // insert an IdentityN node before the inlined function body, and hook all
895 // inputs into that, and we insert another IdentityN node to hook all function
896 // outputs to it.
897 
898 // Returns `Status::OK()` iff `node` is a direct function call of `func`, and we
899 // know how to inline it into the main graph, otherwise returns and error
900 // indicating why the function call is not inlinable.
IsInlinableDirectFunctionCall(const FunctionOptimizerContext & ctx,const FunctionDef & func,const NodeDef & func_node)901 Status IsInlinableDirectFunctionCall(const FunctionOptimizerContext& ctx,
902                                      const FunctionDef& func,
903                                      const NodeDef& func_node) {
904   // Indirect function calls (PartitionedCallOp) have automatic control
905   // dependencies and inlined separately from direct function calls.
906   if (!IsDirectFunctionCall(func, func_node)) {
907     return errors::InvalidArgument("Unsupported function call type: ",
908                                    SummarizeNodeDef(func_node));
909   }
910 
911   // For direct function  calls we insert IdentityN nodes before/after inlined
912   // function body to preserve function call semantics (all inputs evaluated
913   // before function evaluation starts, and all function body nodes finished
914   // before output consumed by other nodes).
915   if (func.signature().input_arg_size() == 0) {
916     return errors::FailedPrecondition(
917         "Can't inline direct function call with empty inputs: ",
918         SummarizeNodeDef(func_node));
919   }
920 
921   // TODO(ezhulenev): Relax constraint on output args?
922   if (func.signature().output_arg_size() == 0) {
923     return errors::FailedPrecondition(
924         "Can't inline direct function call with empty outputs: ",
925         SummarizeNodeDef(func_node));
926   }
927 
928   // Function must execute all the nodes in a function body that might have side
929   // effects. After inlining these nodes into the main graph, we can no longer
930   // guarantee that. For now we disable inlining functions with side effects.
931   //
932   // Attaching control dependency to the output IdentityN node is not safe,
933   // because it might be split or pruned in a later optimization pass.
934   //
935   // Indirect function calls (via PartitionedCallOp) have automatic dependency
936   // tracking, and allow us to safely inline functions with side effects.
937   bool has_side_effects =
938       absl::c_any_of(func.node_def(), [&ctx](const NodeDef& node) {
939         return !IsFreeOfSideEffect(node, &ctx.function_library());
940       });
941   if (has_side_effects) {
942     return errors::FailedPrecondition(
943         "Can't inline function with side-effects in the function body: ",
944         SummarizeNodeDef(func_node));
945   }
946 
947   // We ignore `_noinline` marker in aggressive mode.
948   bool aggressive = ctx.opt_level() == RewriterConfig::AGGRESSIVE;
949   if (MarkedNoInline(func) && !aggressive) {
950     return errors::FailedPrecondition(
951         "Can't inline function marked with '_noinline': ",
952         SummarizeNodeDef(func_node));
953   }
954 
955   // Function specialization and inlining must be mutually exclusive.
956   if (MarkedSpecialized(func)) {
957     return errors::FailedPrecondition(
958         "Can't inline function created in Grappler function specialization: ",
959         SummarizeNodeDef(func_node));
960   }
961 
962   return Status::OK();
963 }
964 
965 // Create an IdentityN node to hook the function inputs to: this ensures that
966 // they're all evaluated before the evaluation of the function body starts.
InlinedFunctionInputsNode(const NodeDef & func_node,const GrapplerFunctionItem & item)967 NodeDef InlinedFunctionInputsNode(const NodeDef& func_node,
968                                   const GrapplerFunctionItem& item) {
969   NodeDef inputs;
970   inputs.set_name(strings::StrCat(func_node.name(), "/", "inlined_inputs"));
971   inputs.set_op("IdentityN");
972   inputs.set_device(func_node.device());
973   *inputs.mutable_input() = func_node.input();
974   AttrValue::ListValue* type_list =
975       (*inputs.mutable_attr())["T"].mutable_list();
976 
977   for (const InputArgExpansion& input_arg : item.inputs()) {
978     for (int i = 0; i < input_arg.placeholders.size(); ++i) {
979       type_list->add_type(input_arg.data_type);
980     }
981   }
982 
983   return inputs;
984 }
985 
986 // Create an IdentityN node to hook the function outputs to: this ensures that
987 // the function body is fully evaluated before its fanout gets scheduled.
InlinedFunctionOutputsNode(const NodeDef & func_node,const GrapplerFunctionItem & item,const absl::flat_hash_map<absl::string_view,absl::string_view> output_tensors)988 NodeDef InlinedFunctionOutputsNode(
989     const NodeDef& func_node, const GrapplerFunctionItem& item,
990     const absl::flat_hash_map<absl::string_view, absl::string_view>
991         output_tensors) {
992   NodeDef outputs;
993   outputs.set_name(func_node.name());
994   outputs.set_op("IdentityN");
995   outputs.set_device(func_node.device());
996   AttrValue::ListValue* type_list =
997       (*outputs.mutable_attr())["T"].mutable_list();
998 
999   for (const OutputArgExpansion& output_arg : item.outputs()) {
1000     for (const string& output_node : output_arg.output_nodes) {
1001       const absl::string_view output_tensor = output_tensors.at(output_node);
1002       type_list->add_type(output_arg.data_type);
1003       outputs.add_input(strings::StrCat(func_node.name(), "/", output_tensor));
1004     }
1005   }
1006 
1007   return outputs;
1008 }
1009 
// Inlines the direct function call `func_node` of the function `func` into
// `optimized_graph`.
//
// On success the following nodes are appended to `optimized_graph`:
//   - an IdentityN node (see InlinedFunctionInputsNode) that takes all inputs
//     of the function call;
//   - the instantiated function body nodes, with names prefixed by
//     `<func_node name>/` and placed on the device of `func_node`. Input
//     placeholders are converted to Identity nodes reading from the inputs
//     node; the Identity nodes added in place of function outputs are dropped;
//   - an IdentityN node (see InlinedFunctionOutputsNode) named exactly as
//     `func_node`, that reads the function output tensors.
//
// Returns an error if the call is not inlinable (IsInlinableDirectFunctionCall),
// if the function can't be instantiated (reported as InvalidArgument), or if a
// body node can't be renamed.
// NOTE(review): a rename failure is returned from inside the loop, after the
// inputs node and possibly some body nodes were already appended to
// `optimized_graph` — confirm that callers discard the graph in this case.
Status InlineDirectFunctionCall(const NodeDef& func_node,
                                const FunctionDef& func,
                                const FunctionOptimizerContext& ctx,
                                GraphDef* optimized_graph) {
  VLOG(2) << "Inline direct function call: " << SummarizeNodeDef(func_node);
  TF_RETURN_IF_ERROR(IsInlinableDirectFunctionCall(ctx, func, func_node));

  const AttrSlice func_instantiation_attr =
      FunctionInstantiationAttributes(func, func_node);

  // Instantiate the function body. Any instantiation failure is reported as
  // InvalidArgument, with the original error message attached.
  GrapplerFunctionItem item;
  Status item_status = MakeGrapplerFunctionItem(func, func_instantiation_attr,
                                                ctx.function_library(),
                                                ctx.graph_version(), &item);

  if (!item_status.ok()) {
    return errors::InvalidArgument("Failed to inline function ", func_node.op(),
                                   " instantiated by ", func_node.name(),
                                   ". Error: ", item_status.error_message());
  }

  // Mapping from input placeholder name to function input position. Positions
  // are assigned sequentially across all expanded input arguments. The keys
  // are views of the placeholder names stored in `item`.
  absl::flat_hash_map<absl::string_view, int> input_placeholders_idx;
  for (const InputArgExpansion& input_arg : item.inputs()) {
    for (const string& placeholder : input_arg.placeholders) {
      const int idx = input_placeholders_idx.size();
      input_placeholders_idx[placeholder] = idx;
    }
  }

  // Names of the nodes that produce function outputs. Identity nodes with
  // these names are bypassed (not copied to the main graph) in the loop below.
  absl::flat_hash_set<absl::string_view> output_nodes;
  for (const OutputArgExpansion& output_arg : item.outputs()) {
    for (const string& output_node : output_arg.output_nodes) {
      output_nodes.insert(output_node);
    }
  }

  // For each function output value we added an identity node that reads the
  // tensor from one of the function body nodes. When we inline function into
  // the main graph we want to bypass these nodes, so we keep a mapping from
  // 'output node name' -> 'output tensor name'.
  absl::flat_hash_map<absl::string_view, absl::string_view> output_tensors;

  // Hook inlined function inputs to IdentityN node.
  NodeDef* func_inputs = optimized_graph->add_node();
  *func_inputs = InlinedFunctionInputsNode(func_node, item);

  for (NodeDef& func_body_node : *item.mutable_function_body().mutable_node()) {
    const string& node_name = func_body_node.name();

    // Skip output identity node, and update a mapping to the output tensor.
    // The skipped node is not swapped out of `item` (see the end of the loop
    // body), so the string views stored in the map stay backed by its strings.
    if (IsIdentity(func_body_node) && output_nodes.count(node_name)) {
      output_tensors.emplace(node_name, func_body_node.input(0));
      continue;
    }

    // Turn placeholders added in place of input arguments into identity nodes.
    const auto input_placeholder_idx = input_placeholders_idx.find(node_name);
    if (input_placeholder_idx != input_placeholders_idx.end()) {
      CHECK_EQ(0, func_body_node.input_size());
      func_body_node.set_op("Identity");
      // The placeholder `dtype` becomes the Identity type attribute `T`; the
      // `dtype` and `shape` attributes are removed from the node.
      (*func_body_node.mutable_attr())["T"] = func_body_node.attr().at("dtype");
      func_body_node.mutable_attr()->erase("dtype");
      func_body_node.mutable_attr()->erase("shape");
      // Read the matching output of the inputs IdentityN node.
      func_body_node.add_input(strings::StrCat(func_inputs->name(), ":",
                                               input_placeholder_idx->second));
    } else {
      // Update the input names if any.
      for (string& input : *func_body_node.mutable_input()) {
        input = AddPrefixToNodeName(input, /*prefix=*/func_node.name());
      }
      // If the node has no input, hook it up to the func_inputs node to
      // ensure it runs in the same frame as the other nodes of the function
      // body.
      if (func_body_node.input_size() == 0) {
        *func_body_node.add_input() = AsControlDependency(func_inputs->name());
      }
    }

    // Add the function node name as a prefix 1) to node name to avoid
    // collisions; 2) to frame name to avoid multiple LoopCond nodes in one
    // frame after inlining.
    const string prefix = strings::StrCat(func_node.name(), "/");
    TF_RETURN_IF_ERROR(
        AddPrefixAndSuffixToNode(prefix, "" /* suffix */, &func_body_node));

    // Make sure the node is placed.
    func_body_node.set_device(func_node.device());

    // Move the node to the main graph. The swap leaves the freshly added
    // (empty) node in its place inside the function body.
    optimized_graph->add_node()->Swap(&func_body_node);
  }

  DCHECK(output_tensors.size() == item.output_size())
      << "Each function output must be mapped to an output tensor";

  // Hook inlined function outputs to IdentityN node.
  NodeDef* func_outputs = optimized_graph->add_node();
  *func_outputs = InlinedFunctionOutputsNode(func_node, item, output_tensors);

  return Status::OK();
}
1113 
InlineSymbolicGradient(const NodeDef & node,FunctionOptimizerContext * ctx,GraphDef * optimized_graph)1114 Status InlineSymbolicGradient(const NodeDef& node,
1115                               FunctionOptimizerContext* ctx,
1116                               GraphDef* optimized_graph) {
1117   VLOG(2) << "Inline symbolic gradient: " << SummarizeNodeDef(node);
1118 
1119   GraphDef graph_def;
1120 
1121   // Create a node to anchor the gradient inputs
1122   NodeDef* inlined_input = graph_def.add_node();
1123   inlined_input->set_name("FunctionInputs");
1124   inlined_input->set_op("IdentityN");
1125   AttrValue::ListValue* type_list =
1126       (*inlined_input->mutable_attr())["T"].mutable_list();
1127   for (const auto& type : node.attr().at("Tin").list().type()) {
1128     type_list->add_type(static_cast<DataType>(type));
1129   }
1130 
1131   // Add the gradient node
1132   NodeDef* inlined = graph_def.add_node();
1133   *inlined = node;
1134   inlined->clear_input();
1135   for (int i = 0; i < node.attr().at("Tin").list().type_size(); ++i) {
1136     inlined->add_input(strings::StrCat(inlined_input->name(), ":", i));
1137   }
1138 
1139   // Create a node to anchor the gradient outputs
1140   NodeDef* inlined_output = graph_def.add_node();
1141   inlined_output->set_name("FunctionOutputs");
1142   inlined_output->set_op("IdentityN");
1143   type_list = (*inlined_output->mutable_attr())["T"].mutable_list();
1144   for (const auto& type : node.attr().at("Tout").list().type()) {
1145     type_list->add_type(static_cast<DataType>(type));
1146   }
1147   for (int i = 0; i < node.attr().at("Tout").list().type_size(); ++i) {
1148     inlined_output->add_input(strings::StrCat(inlined->name(), ":", i));
1149   }
1150 
1151   // Convert the graphdef to a graph
1152   GraphConstructorOptions graph_ctor_opts;
1153   graph_ctor_opts.allow_internal_ops = true;
1154   graph_ctor_opts.expect_device_spec = false;
1155   Graph graph(ctx->function_library());
1156   TF_RETURN_IF_ERROR(
1157       ConvertGraphDefToGraph(graph_ctor_opts, graph_def, &graph));
1158 
1159   FunctionLibraryRuntime* flr = ctx->mutable_function_library_runtime();
1160 
1161   // 1. Inline symbolic gradient node.
1162   const bool expanded = ExpandInlineFunctions(flr, &graph);
1163   if (!expanded) {
1164     return errors::Internal("Failed to expand SymbolicGradient op");
1165   }
1166 
1167   // TODO(ezhulenev): InlineFunctionBody in common_runtime/function silently
1168   // fails to inline function into the graph, and leaves the graph unmodified.
1169   // We check that graph has our symbolic gradient inlined, otherwise we return
1170   // a error.
1171   const auto is_symbolic_gradient_op = [&](const Node* node) {
1172     return node->name() == inlined->name() &&
1173            node->type_string() == "SymbolicGradient";
1174   };
1175   for (Node* node : graph.nodes()) {
1176     if (is_symbolic_gradient_op(node)) {
1177       return errors::Internal("Failed to inline symbolic gradient node: ",
1178                               SummarizeNode(*node));
1179     }
1180   }
1181 
1182   // 2. Recursively inline nested function calls.
1183   int iteration = 0;
1184   while (ExpandInlineFunctions(flr, &graph)) {
1185     if (++iteration >= 50) {
1186       VLOG(2) << "Break symbolic gradient inlining loop at iteration #"
1187               << iteration;
1188       break;
1189     }
1190   }
1191 
1192   GraphDef inlined_graph_def;
1193   graph.ToGraphDef(&inlined_graph_def);
1194 
1195   // Add the default values of attributes to the nodes that have been inlined.
1196   TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&inlined_graph_def,
1197                                                *graph.op_registry(), 0, true));
1198 
1199   // Add the inlined nodes to the graph
1200   for (NodeDef& inlined_node : *inlined_graph_def.mutable_node()) {
1201     if (inlined_node.name() == "FunctionOutputs") {
1202       inlined_node.set_name(node.name());
1203       for (int i = 0; i < inlined_node.input_size(); ++i) {
1204         inlined_node.set_input(
1205             i, AddPrefixToNodeName(inlined_node.input(i), node.name()));
1206       }
1207     } else if (inlined_node.name() == "FunctionInputs") {
1208       inlined_node.set_name(
1209           AddPrefixToNodeName(inlined_node.name(), node.name()));
1210       inlined_node.clear_input();
1211       for (int i = 0; i < node.input_size(); ++i) {
1212         inlined_node.add_input(node.input(i));
1213       }
1214     } else {
1215       inlined_node.set_name(
1216           AddPrefixToNodeName(inlined_node.name(), node.name()));
1217       for (int i = 0; i < inlined_node.input_size(); ++i) {
1218         inlined_node.set_input(
1219             i, AddPrefixToNodeName(inlined_node.input(i), node.name()));
1220       }
1221       // If the node has no input, hook it up to the function input node to make
1222       // sure it runs in the same frame as the other nodes of the function body.
1223       if (inlined_node.input_size() == 0) {
1224         *inlined_node.add_input() = AsControlDependency(
1225             AddPrefixToNodeName("FunctionInputs", node.name()));
1226       }
1227     }
1228     inlined_node.set_device(node.device());
1229     optimized_graph->add_node()->Swap(&inlined_node);
1230   }
1231 
1232   return Status::OK();
1233 }
1234 
1235 // -------------------------------------------------------------------------- //
// Inline indirect function calls (aka PartitionedCallOp).
1237 //
1238 // When we inline indirect function calls, we instantiate the function body from
1239 // its FunctionDef and caller node attributes, and embed the instantiated graph
1240 // into the "main graph".
1241 //
1242 // In contrast to direct function calls, `PartitionedCallOp` has automatic
1243 // dependency tracking via input/output control edges, and we relax some of the
1244 // constraints that we have for direct function call inlining.
1245 //
1246 // Automatic control dependency rules:
1247 //
// 1) When a `PartitionedCallOp` function has a resource (DT_RESOURCE data
//    type) input argument it "captures" the mutable resource.  This is
//    implemented by automatically adding an incoming control edge from the
//    previous side-effectful op touching that resource, and an outgoing control
//    edge to the next side-effectful op using the same resource. This
//    serializes the mutations of the resource to make graph execution
//    deterministic.
1255 //
1256 // 2) All stateful ops inside a function body are guaranteed to execute in
1257 //    program order, this is achieved by adding control edges between stateful
1258 //    ops at graph construction time.
1259 //
1260 // 3) Furthermore, all ops accepting the same resource as an input are
1261 //    guaranteed to run in program order. This is also done by adding control
1262 //    edges at graph construction time. The last op touching the resource
1263 //    will have an outgoing control edge to all function return nodes, which
1264 //    will guarantee that all side effects to the resource will happen before
1265 //    function completion.
1266 //
1267 // Function call inlining must preserve side effect visibility:
1268 //
// 1) All side effects to the captured resources, that happened before function
//    call must be visible to the function body nodes using those resources.
1271 // 2) All side effects to the captured resources, that happened inside function
1272 //    body, must be visible to every op/function using that resource after the
1273 //    function call completed.
1274 //
1275 // To guarantee that these properties are preserved after inlining we:
1276 //
1277 // 1) Create "input_control" NoOp. Function call node incoming control edges
1278 //    will be forwarded *to* this node. Function inputs (Identity nodes) will
1279 //    have a control edge *from* this node. If function has no inputs, by
1280 //    construction it must have nodes without inputs in the function body, and
1281 //    in this case these nodes will have a control edge *from* this node.
//
1283 // 2) Create "output_control" NoOp. All nodes that have incoming control edge
1284 //    *from* the function call node, will be forwarded to this node. Function
1285 //    outputs (Identity nodes) will have a control edge *to* this node. This
1286 //    will guarantee that nodes that have control dependency on the function
1287 //    call, will observe all side-effects (guaranteed by graph construction with
1288 //    automatic control dependencies tracking).
1289 //
1290 // If after function instantiation we find a stateful or a dataset op inside
1291 // the function body, that is not reachable from any of the function outputs (or
1292 // if the function has no outputs), we do not inline it, because we can't
1293 // guarantee that these nodes will be executed in correct order (or executed at
1294 // all) after inlining.
1295 //
1296 // We do not try to add any extra control edges to make sure that all
1297 // side-effectful nodes will be executed, that should be handled at graph
1298 // construction time.
1299 
// Element type of the result vector filled by `MaybeDeadOutputs`: describes a
// function output that might return a dead tensor. Both members are plain
// pointers; the struct does not manage the lifetime of the nodes.
struct MaybeDeadOutput {
  // Node that can produce a dead tensor.
  // NOTE(review): presumably a `Switch` node or a function call node of the
  // function body — confirm against `MaybeDeadOutputs`.
  const NodeDef* dead_tensor_src;
  // NOTE(review): presumably the function output node that can receive the
  // dead tensor from `dead_tensor_src` — confirm against `MaybeDeadOutputs`.
  const NodeDef* output_node_dst;
};
1304 
1305 // Finds all function outputs that might return a dead tensor. This can happen
1306 // if there is no `Merge` node on the path from the `Switch` node, to the
1307 // function output.
MaybeDeadOutputs(const FunctionOptimizerContext & ctx,const GrapplerFunctionItem & item,std::vector<MaybeDeadOutput> * maybe_dead)1308 Status MaybeDeadOutputs(const FunctionOptimizerContext& ctx,
1309                         const GrapplerFunctionItem& item,
1310                         std::vector<MaybeDeadOutput>* maybe_dead) {
1311   VLOG(3) << "Find function outputs that might return dead tensors: item.id="
1312           << item.id;
1313   DCHECK(maybe_dead->empty()) << "Input argument must be an empty vector";
1314 
1315   std::vector<const NodeDef*> dead_tensor_srcs;
1316   for (const NodeDef& node : item.graph.node()) {
1317     if (IsSwitch(node)) {
1318       VLOG(4) << "Add dead tensors source. Switch node: " << node.name();
1319       dead_tensor_srcs.push_back(&node);
1320       continue;
1321     }
1322 
1323     // Regular (aka 'direct') function call can also produce dead tensors if
1324     // the function body has mergeless switches.
1325     const FunctionDef* func = ctx.function_library().Find(node.op());
1326     if (func != nullptr) {
1327       GrapplerFunctionItem func_item;
1328       TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
1329           *func, FunctionInstantiationAttributes(*func, node),
1330           ctx.function_library(), ctx.graph_version(), &func_item));
1331 
1332       std::vector<MaybeDeadOutput> func_dead_outputs;
1333       TF_RETURN_IF_ERROR(MaybeDeadOutputs(ctx, func_item, &func_dead_outputs));
1334 
1335       if (!func_dead_outputs.empty()) {
1336         VLOG(4) << "Add dead tensors source. Function call: " << node.op()
1337                 << " node=" << node.name();
1338         dead_tensor_srcs.push_back(&node);
1339       }
1340     }
1341   }
1342 
1343   // If we do not have dead tensor sources in the function body, it's
1344   // guaranteed that all output tensors can't become dead.
1345   if (dead_tensor_srcs.empty()) return Status::OK();
1346 
1347   // Names of the function body nodes that return function output values.
1348   absl::flat_hash_set<absl::string_view> output_nodes;
1349   for (const auto& output_expansion : item.outputs()) {
1350     for (const auto& output_node : output_expansion.output_nodes) {
1351       output_nodes.insert(output_node);
1352     }
1353   }
1354 
1355   GraphTopologyView topology_view;
1356   TF_RETURN_IF_ERROR(topology_view.InitializeFromGraph(item.graph));
1357 
1358   for (const NodeDef* dead_tensor_src : dead_tensor_srcs) {
1359     DfsTraversal(topology_view, {dead_tensor_src},
1360                  TraversalDirection::kFollowOutputs,
1361                  // Stop traversal when reached first `Merge` node.
1362                  DfsPredicates::Advance(
1363                      [](const NodeDef* node) { return !IsMerge(*node); }),
1364                  // If we reached output node, add MaybeDeadOutput edge.
1365                  DfsCallbacks::PreOrder([&](const NodeDef* node) {
1366                    if (output_nodes.find(node->name()) != output_nodes.end()) {
1367                      maybe_dead->push_back({dead_tensor_src, node});
1368                    }
1369                  }));
1370   }
1371 
1372   return Status::OK();
1373 }
1374 
1375 // Returns `Status::OK()` iff `node` is an indirect function call of `func`, and
1376 // we know how to inline it into the main graph, otherwise returns and error
1377 // indicating why the function call is not inlinable.
IsInlinableIndirectFunctionCall(const FunctionOptimizerContext & ctx,const FunctionDef & func,const NodeDef & func_node)1378 Status IsInlinableIndirectFunctionCall(const FunctionOptimizerContext& ctx,
1379                                        const FunctionDef& func,
1380                                        const NodeDef& func_node) {
1381   // We inline direct function calls above, using different rules.
1382   if (!IsIndirectFunctionCall(func, func_node)) {
1383     return errors::InvalidArgument("Unsupported function call type: ",
1384                                    SummarizeNodeDef(func_node));
1385   }
1386 
1387   if (MarkedNoInline(func)) {
1388     return errors::FailedPrecondition(
1389         "Can't inline function marked with '_noinline': ",
1390         SummarizeNodeDef(func_node));
1391   }
1392 
1393   // Function specialization and inlining must be mutually exclusive.
1394   if (MarkedSpecialized(func)) {
1395     return errors::FailedPrecondition(
1396         "Can't inline function created in Grappler function specialization: ",
1397         SummarizeNodeDef(func_node));
1398   }
1399 
1400   // We can't inline functions that are in a fetch set, because it would
1401   // invalidate fetch tensors (function call node fully inlined and doesn't
1402   // exist in the optimized graph).
1403   if (ctx.IsFetchNode(func_node.name())) {
1404     return errors::FailedPrecondition(
1405         "Can't inline function in a Grappler item fetch set: ",
1406         SummarizeNodeDef(func_node));
1407   }
1408 
1409   return Status::OK();
1410 }
1411 
1412 // Checks that all side-effects will be executed in well defined order. We do it
1413 // by checking if there is a path from stateful/dataset ops to one of the
1414 // control output nodes.
CheckThatSideEffectsWillExecute(const FunctionOptimizerContext & ctx,const GraphTopologyView & graph_topo_view,const absl::flat_hash_set<string> control_output_nodes)1415 Status CheckThatSideEffectsWillExecute(
1416     const FunctionOptimizerContext& ctx,
1417     const GraphTopologyView& graph_topo_view,
1418     const absl::flat_hash_set<string> control_output_nodes) {
1419   // In aggressive mode we just print a warning for side-effectful nodes that
1420   // might not be executed after inlining.
1421   const bool aggressive = ctx.opt_level() == RewriterConfig::AGGRESSIVE;
1422 
1423   for (const NodeDef& func_body_node : graph_topo_view.graph()->node()) {
1424     const bool node_must_execute =
1425         IsDataset(func_body_node) ||
1426         IsStateful(func_body_node, &ctx.function_library());
1427 
1428     // If op has DT_RESOURCE argument it will be marked as stateful, though if
1429     // it only reads from that resource, it's allowed to prune it, because it
1430     // can't produce any visible side-effects.
1431     const bool read_only = IsReadVariableOp(func_body_node);
1432 
1433     if (read_only || !node_must_execute) continue;
1434 
1435     VLOG(3) << "Check that node " << func_body_node.name()
1436             << " will execute after inlining.";
1437     bool will_execute = false;
1438 
1439     // Check if we reached one of the output nodes.
1440     const auto callbacks = DfsCallbacks::PreOrder([&](const NodeDef* node) {
1441       if (control_output_nodes.contains(node->name())) {
1442         VLOG(4) << "Found a path to control output node: " << node->name();
1443         will_execute = true;
1444       }
1445     });
1446 
1447     // Stop if we already proved that node will execute.
1448     const auto predicates = DfsPredicates::Enter(
1449         [&](const NodeDef* node) { return !will_execute; });
1450 
1451     DfsTraversal(graph_topo_view, {&func_body_node},
1452                  TraversalDirection::kFollowOutputs, predicates, callbacks);
1453 
1454     if (!will_execute) {
1455       const string error_message = absl::StrCat(
1456           "Can't guarantee execution of a side-effectful node, that is not "
1457           "reachable from function outputs. Function body node: ",
1458           SummarizeNodeDef(func_body_node));
1459 
1460       if (aggressive) {
1461         LOG(WARNING) << error_message;
1462       } else {
1463         return errors::Internal(error_message);
1464       }
1465     }
1466   }
1467 
1468   return Status::OK();
1469 }
1470 
PlaceInlinedFunctionBody(const NodeDef & func_node,const GrapplerFunctionItem & item,const absl::flat_hash_map<absl::string_view,int> & input_placeholders_idx,FunctionOptimizerContext * ctx,GraphDef * placed_graph_def)1471 Status PlaceInlinedFunctionBody(
1472     const NodeDef& func_node, const GrapplerFunctionItem& item,
1473     const absl::flat_hash_map<absl::string_view, int>& input_placeholders_idx,
1474     FunctionOptimizerContext* ctx, GraphDef* placed_graph_def) {
1475   // Control flow lowering and Placer works with a Graph object.
1476   std::unique_ptr<Graph> func_body_graph =
1477       absl::make_unique<Graph>(ctx->function_library());
1478 
1479   GraphConstructorOptions opts;
1480   TF_RETURN_IF_ERROR(
1481       ConvertGraphDefToGraph(opts, item.graph, func_body_graph.get()));
1482 
1483   // ------------------------------------------------------------------------ //
1484   // Grappler receives the graph after PRE_PLACEMENT, Placer, and POST_PLACEMENT
1485   // passes, so each node has a valid device assignment. Also V2 control
1486   // flow ops (functional If and While) should have been lowered to V1 control
1487   // flow (Switch and Merge nodes). To keep the graph valid for execution we
1488   // must assign device to every inlined graph node, and also lower the control
1489   // flow.
1490 
1491   GraphOptimizationPassOptions opt_options;
1492   opt_options.graph = &func_body_graph;
1493   opt_options.flib_def = ctx->mutable_function_library();
1494 
1495   // TODO(ezhulenev): Should we run full PRE_PLACEMENT pass here? And
1496   // POST_PLACEMENT after placer?
1497   LowerIfWhilePass pass;
1498   TF_RETURN_IF_ERROR(pass.Run(opt_options));
1499 
1500   // ------------------------------------------------------------------------ //
1501   // Before placing the function body nodes we pin input placeholders to the
1502   // same device as their corresponding input nodes.
1503 
1504   for (Node* func_body_node : func_body_graph->nodes()) {
1505     const auto input_placeholder_idx =
1506         input_placeholders_idx.find(func_body_node->name());
1507 
1508     if (input_placeholder_idx != input_placeholders_idx.end()) {
1509       const int input_idx = input_placeholder_idx->second;
1510       const GraphView::OutputPort output_port =
1511           ctx->graph_view().GetRegularFanin({&func_node, input_idx});
1512 
1513       VLOG(3) << "Pin inlined function input node '" << func_body_node->name()
1514               << "' to the '" << output_port.node->device() << "' device.";
1515       func_body_node->set_requested_device(output_port.node->device());
1516     }
1517   }
1518 
1519   // ------------------------------------------------------------------------ //
1520   // After placing nodes corresponding to the function inputs, we need to assign
1521   // device placements to all other function body nodes.
1522 
1523   const DeviceSet* devices = ctx->devices();
1524 
1525   if (devices->devices().empty()) {
1526     // If there are no devices available for placer, we just put all nodes to
1527     // the same device as a function caller node. This can happen if Grappler is
1528     // running "offline", without active runtime session, for example as a part
1529     // of a batch job for graph analysis/optimization.
1530     VLOG(3) << "Assign function call node device to all function body nodes. "
1531             << "Device: " << func_node.device();
1532     for (Node* func_body_node : func_body_graph->nodes()) {
1533       func_body_node->set_requested_device(func_node.device());
1534     }
1535   } else {
1536     // If we are running in an active runtime session, Grappler will get the
1537     // graph after initial placing is done, and we should have devices for the
1538     // placer.
1539     VLOG(3) << "Run placer for instantiated function body. Devices: ["
1540             << absl::StrJoin(
1541                    devices->devices(), ", ",
1542                    [](string* out, const Device* d) { out->append(d->name()); })
1543             << "]";
1544 
1545     // Use function caller node device as a default for placer.
1546     const Device* default_device =
1547         devices->FindDeviceByName(func_node.device());
1548 
1549     Placer placer(func_body_graph.get(), devices, default_device);
1550     TF_RETURN_IF_ERROR(placer.Run());
1551   }
1552 
1553   // Convert Graph back to the placed GraphDef.
1554   func_body_graph->ToGraphDef(placed_graph_def);
1555 
1556   return Status::OK();
1557 }
1558 
// Inlines the indirect function call `func_node` (a call of `func`) into
// `optimized_graph`.
//
// The function body is instantiated as a GrapplerFunctionItem, placed (see
// PlaceInlinedFunctionBody), and every body node is renamed with the
// `func_node.name()` prefix before it is moved into `optimized_graph`. The
// function call node itself is not copied. Instead the following information
// is recorded in `ctx`:
//   - tensor mapping from the outputs of `func_node` to the tensors produced
//     by inlined nodes;
//   - control overrides from `func_node` to the inlined "side effects
//     executed" node (only if that node was added).
//
// Returns an error if the call is not inlinable, the function can't be
// instantiated, the function might return dead tensors, placement fails, or
// execution of side-effectful nodes can't be guaranteed. All error returns
// happen before any node is added to `optimized_graph`.
Status InlineIndirectFunctionCall(const NodeDef& func_node,
                                  const FunctionDef& func,
                                  FunctionOptimizerContext* ctx,
                                  GraphDef* optimized_graph) {
  VLOG(2) << "Inline indirect function call: " << SummarizeNodeDef(func_node);
  VLOG(4) << "Inlined function definition: " << DebugString(func);
  TF_RETURN_IF_ERROR(IsInlinableIndirectFunctionCall(*ctx, func, func_node));

  const AttrSlice func_instantiation_attr =
      FunctionInstantiationAttributes(func, func_node);

  GrapplerFunctionItem item;
  Status item_status = MakeGrapplerFunctionItem(func, func_instantiation_attr,
                                                ctx->function_library(),
                                                ctx->graph_version(), &item);

  if (!item_status.ok()) {
    return errors::InvalidArgument("Failed to inline function ", func_node.op(),
                                   " instantiated by ", func_node.name(),
                                   ". Error: ", item_status.error_message());
  }

  // `PartitionedCallOp` invokes functions with `allow_dead_tensors = true` to
  // reset dead flag, and return default initialized tensors instead of dead
  // tensors. There is no way to express this in a regular Tensorflow graph, so
  // we choose not to inline if a function can have dead tensors as an output
  // position. In practice `mergeless switches` should not exist in a function
  // body, because tf-eager will only use v2 control flow ops.
  std::vector<MaybeDeadOutput> maybe_dead_outputs;
  TF_RETURN_IF_ERROR(MaybeDeadOutputs(*ctx, item, &maybe_dead_outputs));
  if (!maybe_dead_outputs.empty()) {
    // Formats a MaybeDeadOutput as a summary of its dead tensor source node.
    struct MaybeDeadOutputFormatter {
      void operator()(string* out, const MaybeDeadOutput& md) const {
        absl::StrAppend(out, SummarizeNodeDef(*md.dead_tensor_src));
      }
    };
    return errors::FailedPrecondition(
        "Can't inline function with dead outputs. Dead tensor sources (size = ",
        maybe_dead_outputs.size(), "): ",
        absl::StrJoin(maybe_dead_outputs, "\n", MaybeDeadOutputFormatter()));
  }

  GraphView::InputPort control_input_port =
      ctx->graph_view().GetInputPort(func_node.name(), Graph::kControlSlot);
  GraphView::OutputPort control_output_port =
      ctx->graph_view().GetOutputPort(func_node.name(), Graph::kControlSlot);

  // Nodes that have side effects to the captured resources. Computed from the
  // control fanin of the function call node; used below only for logging.
  std::vector<string> happens_before;
  absl::c_transform(
      ctx->graph_view().GetFanin(control_input_port),
      std::back_inserter(happens_before),
      [](const GraphView::OutputPort port) { return port.node->name(); });

  VLOG(3) << "Happens before set (size = " << happens_before.size()
          << "): " << absl::StrJoin(happens_before, ", ");

  // Nodes that must observe side effects to the captured resources. Computed
  // from the control fanout of the function call node.
  std::vector<string> happens_after;
  absl::c_transform(
      ctx->graph_view().GetFanout(control_output_port),
      std::back_inserter(happens_after),
      [](const GraphView::InputPort port) { return port.node->name(); });

  VLOG(3) << "Happens after set (size = " << happens_after.size()
          << "): " << absl::StrJoin(happens_after, ", ");

  // Regular (data) inputs to the function call. Collection stops at the first
  // control input (this relies on control inputs following all data inputs).
  std::vector<SafeTensorId> inputs;
  for (const string& input : func_node.input()) {
    SafeTensorId tensor_id = ParseTensorName(input);
    if (tensor_id.index() == Graph::kControlSlot) break;
    inputs.push_back(tensor_id);
  }

  // Mapping from input placeholder name to function input position. Positions
  // are assigned in the order placeholders appear in `item.inputs()`.
  absl::flat_hash_map<absl::string_view, int> input_placeholders_idx;
  for (const InputArgExpansion& input_arg : item.inputs()) {
    for (const string& placeholder : input_arg.placeholders) {
      const int idx = input_placeholders_idx.size();
      input_placeholders_idx[placeholder] = idx;
    }
  }

  const string prefix = strings::StrCat(func_node.name(), "/");

  // ------------------------------------------------------------------------ //
  // For each function output value we added an identity node that reads the
  // tensor from one of the function body nodes. When we inline function into
  // the main graph we want to bypass these nodes, so we keep a mapping from
  // 'output node name' -> 'output tensor name'.
  absl::flat_hash_map<string, string> output_tensors;

  // Unique names of nodes producing tensors in `output_tensors`.
  absl::flat_hash_set<string> output_tensors_nodes;

  // Identity nodes added to the function body in place of function outputs.
  absl::flat_hash_set<string> output_nodes;
  for (const OutputArgExpansion& output_arg : item.outputs()) {
    for (const string& output_node : output_arg.output_nodes) {
      output_nodes.insert(output_node);
    }
  }

  for (const NodeDef& func_body_node : item.graph.node()) {
    const string& node_name = func_body_node.name();

    if (IsIdentity(func_body_node) && output_nodes.count(node_name)) {
      const string& output_tensor = func_body_node.input(0);
      output_tensors.emplace(node_name, output_tensor);

      SafeTensorId tensor_id = ParseTensorName(output_tensor);
      output_tensors_nodes.insert(tensor_id.node());
    }
  }

  // ------------------------------------------------------------------------ //
  // IMPORTANT: Actual inputs will be added to the following nodes at the very
  // last stage, because we don't want to have invalid edges in a function body
  // graph (control edges that depend on the nodes in the "outer" optimized
  // graph).

  // If one of the function inputs is a dead tensor, we must not execute any of
  // the function body nodes, and let the dead tensor flag propagate through the
  // inlined function body. We add NoOp inputs_ready node, and add control edges
  // to it from all input nodes. Inlined function arguments (Identity nodes)
  // will have a control dependency on it.
  //
  // TODO(ezhulenev): We do not need to provide this guarantee for ALL nodes in
  // the function body. We must only ensure that we do not generate observable
  // side effects.
  //
  // If the function call node has incoming control edges, we will update them
  // to use this node as destination, to ensure side-effects execution order.
  NodeDef* inputs_ready_node = nullptr;
  if (func_node.input_size() > 0) {
    inputs_ready_node = item.graph.add_node();
    inputs_ready_node->set_op("NoOp");
    inputs_ready_node->set_name(kInputsReadyNodeName);
  }

  // All nodes that have a control edge from the function call node, will be
  // updated to have a control edge from 'side_effects_executed_node`. This node
  // will have control edges from all function control outputs (see
  // `control_ret` in FunctionDef). This is a "barrier" that guarantees that all
  // ops with side effects in the function body were executed.
  //
  // If the function call node has no outgoing control edges, it means that no
  // one is interested in the function side-effect affecting captured resources.
  //
  // If node is in keep_ops set, it means that it must execute. This could
  // happen if the graph is an instantiation of a function with control output.
  NodeDef* side_effects_executed_node = nullptr;
  if (!happens_after.empty() || ctx->IsKeepOp(func_node.name())) {
    side_effects_executed_node = item.graph.add_node();
    side_effects_executed_node->set_op("NoOp");
    side_effects_executed_node->set_name(kSideEffectsExecutedNodeName);
  }

  // If function executed only for the regular data outputs, it's totally safe
  // to prune side-effects. If side-effects order is important, it must be
  // captured at graph construction time via control edges. This case is only
  // logged here.
  if (item.control_output_size() > 0 && happens_after.empty()) {
    VLOG(2) << "Function has control outputs and empty happens after set.";
  }

  // ------------------------------------------------------------------------ //
  // If we have a node inside the function body without inputs (e.g. Const), we
  // must attach a control dependency to it, to make sure that if a function
  // call happens inside a loop, the node will be evaluated in correct frame.
  //
  // If the function call node has no inputs and no control dependencies, it
  // means that it can't be a function call inside a loop, and we can safely
  // insert that node without inputs into the main graph.
  //
  // TODO(ezhulenev): Use FrameMap (see grappler/utils/frame.h) to find out if
  // the function is called inside a loop.
  std::vector<string> empty_inputs_hook;
  if (inputs_ready_node != nullptr) {
    empty_inputs_hook.push_back(inputs_ready_node->name());
  }

  // ------------------------------------------------------------------------ //
  // Grappler called after PRE_PLACEMENT and PLACEMENT passes, so we have to
  // make sure that after inlining all nodes will have valid device assignment.

  GraphDef placed_graph_def;
  TF_RETURN_IF_ERROR(PlaceInlinedFunctionBody(
      func_node, item, input_placeholders_idx, ctx, &placed_graph_def));

  // ------------------------------------------------------------------------ //
  // After all nodes placed we need to prepare them for inlining into the
  // optimized graph: turn placeholders into identities, update nodes
  // connectivity, etc...

  // Returns the name that a function body node (or an input referring to it)
  // gets after inlining: `name` prefixed with the function call node name.
  const auto inlined_node_name = [&func_node](const string& name) -> string {
    return AddPrefixToNodeName(name, /*prefix=*/func_node.name());
  };

  for (NodeDef& func_body_node : *placed_graph_def.mutable_node()) {
    const string& node_name = func_body_node.name();

    // Turn placeholders added in place of input arguments into identity nodes.
    const auto input_placeholder_idx = input_placeholders_idx.find(node_name);
    if (input_placeholder_idx != input_placeholders_idx.end()) {
      DCHECK_EQ(0, func_body_node.input_size());
      func_body_node.set_op("Identity");
      // Identity uses the "T" type attribute, and has no "dtype"/"shape".
      (*func_body_node.mutable_attr())["T"] = func_body_node.attr().at("dtype");
      func_body_node.mutable_attr()->erase("dtype");
      func_body_node.mutable_attr()->erase("shape");
      // Read directly from the tensor that was passed to the function call.
      const int input_idx = input_placeholder_idx->second;
      func_body_node.add_input(inputs[input_idx].ToString());

      // Add a control dependency on 'inputs_ready' node, to guarantee that all
      // inputs are alive and all side-effects executed before function body.
      if (inputs_ready_node) {
        func_body_node.add_input(
            AsControlDependency(inlined_node_name(inputs_ready_node->name())));
      }
    } else {
      // Update inputs of the regular function body nodes.
      for (string& input : *func_body_node.mutable_input()) {
        input = inlined_node_name(input);
      }

      // Check if we need to ensure node execution in correct loop frame.
      bool node_needs_empty_inputs_hook =
          // We have a node to hook and node has no inputs.
          !empty_inputs_hook.empty() && func_body_node.input_size() == 0 &&
          // Inputs ready node will always have edge from main graph. If
          // function call has no regular and control inputs, we will not add
          // inputs_ready node to the function body graph.
          node_name != kInputsReadyNodeName &&
          // The node acting as a return barrier for execution of side effects
          // might not have any inputs (in case function has no control outputs,
          // but we still added it because of non-empty happens-after set), so
          // we must make sure it's executed in correct frame.
          (node_name != kSideEffectsExecutedNodeName ||
           item.control_output_size() == 0);

      if (node_needs_empty_inputs_hook) {
        *func_body_node.add_input() =
            AsControlDependency(inlined_node_name(empty_inputs_hook[0]));
      }
    }

    // Add the function node name as a prefix 1) to node name to avoid
    // collisions; 2) to frame name to avoid multiple LoopCond nodes in one
    // frame after inlining.
    TF_RETURN_IF_ERROR(
        AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &func_body_node));

    // After inlining into the optimized graph, NodeDef must have all attributes
    // defined, which is not required for a node in a FunctionDef.
    const OpDef* op_def;
    TF_RETURN_IF_ERROR(
        ctx->function_library().LookUpOpDef(func_body_node.op(), &op_def));
    AddDefaultsToNodeDef(*op_def, &func_body_node);
  }

  // ------------------------------------------------------------------------ //
  // Check that after inlining all side-effects will be executed in well defined
  // order. We do it by checking if there is a path from stateful/dataset ops to
  // one of the control output nodes.

  // Because we rename all the nodes before inlining, we need a copy of
  // output_nodes with the new names.
  absl::flat_hash_set<string> inlined_output_nodes;
  for (const string& output_node : output_nodes) {
    inlined_output_nodes.insert(inlined_node_name(output_node));
  }
  const auto is_inlined_output_node = [&](const NodeDef& node) -> bool {
    return inlined_output_nodes.find(node.name()) != inlined_output_nodes.end();
  };

  // Names of the inlined control output nodes.
  absl::flat_hash_set<string> inlined_control_output_nodes;
  for (const ControlOutput& control_output : item.control_outputs()) {
    inlined_control_output_nodes.insert(
        inlined_node_name(control_output.node_name));
  }

  // Construct a graph topology view for DFS traversals (skip invalid edges for
  // input nodes connected to nodes in the optimized graph).
  GraphTopologyView placed_topo_view(/*skip_invalid_edges=*/true);
  TF_RETURN_IF_ERROR(placed_topo_view.InitializeFromGraph(placed_graph_def));
  TF_RETURN_IF_ERROR(CheckThatSideEffectsWillExecute(
      *ctx, placed_topo_view, inlined_control_output_nodes));

  // ------------------------------------------------------------------------ //
  // Move all the nodes to the optimized graph after successful preprocessing.

  // Connect the inlined 'inputs_ready' node to the nodes of the outer graph
  // that produce the function call inputs (one control edge per unique node).
  if (inputs_ready_node != nullptr) {
    string inlined_node = inlined_node_name(inputs_ready_node->name());
    // NOTE(review): `node_idx` is dereferenced below without a check; this
    // presumes the node is present in `placed_graph_def` - confirm.
    absl::optional<int> node_idx = placed_topo_view.GetNodeIndex(inlined_node);

    absl::flat_hash_set<string> input_nodes;
    for (const string& input : func_node.input()) {
      SafeTensorId tensor = ParseTensorName(input);

      // Input node might have been a function call that was already inlined.
      // Follow the chain of tensor mappings to the final tensor.
      auto it = ctx->tensor_mapping().find(tensor);
      while (it != ctx->tensor_mapping().end()) {
        tensor = it->second;
        it = ctx->tensor_mapping().find(tensor);
      }

      if (input_nodes.insert(tensor.node()).second) {
        placed_graph_def.mutable_node(*node_idx)->add_input(
            AsControlDependency(tensor.node()));
      }
    }
  }

  if (side_effects_executed_node != nullptr) {
    string inlined_node = inlined_node_name(side_effects_executed_node->name());
    // NOTE(review): `node_idx` is dereferenced below without a check; this
    // presumes the node is present in `placed_graph_def` - confirm.
    absl::optional<int> node_idx = placed_topo_view.GetNodeIndex(inlined_node);

    // Add control edges from all control output nodes.
    for (const string& node_name : inlined_control_output_nodes) {
      placed_graph_def.mutable_node(*node_idx)->add_input(
          AsControlDependency(node_name));
    }

    // Forward all control dependencies in the optimized graph to the new node.
    ctx->AddControlOverrides(func_node, {inlined_node});
  }

  for (NodeDef& func_body_node : *placed_graph_def.mutable_node()) {
    // Skip output identity nodes.
    if (IsIdentity(func_body_node) && is_inlined_output_node(func_body_node))
      continue;

    // Swap moves the node content into the optimized graph without a copy.
    optimized_graph->add_node()->Swap(&func_body_node);
  }

  // Indirect function call is fully inlined into the optimized graph, and we do
  // not copy the original function call node, so we have to setup tensor
  // mapping from old output tensors, to the outputs of inlined nodes.
  int output_idx = 0;
  for (const OutputArgExpansion& output : item.outputs()) {
    for (const string& output_node : output.output_nodes) {
      // Lookup presumes that every output node was recorded in
      // `output_tensors` above, i.e. that it is an Identity node.
      const string& output_tensor = output_tensors.at(output_node);

      const SafeTensorId from_tensor(func_node.name(), output_idx++);
      const SafeTensorId to_tensor = ParseTensorName(output_tensor);

      const SafeTensorId inlined_to_tensor =
          SafeTensorId(absl::StrCat(func_node.name(), "/", to_tensor.node()),
                       to_tensor.index());

      ctx->AddTensorMapping(from_tensor, inlined_to_tensor);
    }
  }

  // If function call node was in keep_ops set, it means that we need to keep a
  // node with the same name in the optimized graph. We forward all data
  // consumers to inlined nodes, and we verify that the node is not in a fetch
  // set, so it's safe to assume that the function call node is only required
  // for a control edge source.
  if (ctx->IsKeepOp(func_node.name())) {
    VLOG(4) << "Add NoOp for inlined function in keep ops set.";
    NodeDef* keep_func_node = optimized_graph->add_node();
    keep_func_node->set_op("NoOp");
    keep_func_node->set_name(func_node.name());
    keep_func_node->set_device(func_node.device());
    keep_func_node->add_input(
        AsControlDependency(inlined_node_name(kSideEffectsExecutedNodeName)));
  }

  VLOG(3) << "Successfully inlined indirect function call: "
          << SummarizeNodeDef(func_node);

  return Status::OK();
}
1934 
1935 // Restores graph invariants after function specialization and inlining: all
1936 // inputs must be connected to valid nodes.
RestoreGraphInvariants(const FunctionOptimizerContext & ctx,GraphDef * optimized_graph)1937 Status RestoreGraphInvariants(const FunctionOptimizerContext& ctx,
1938                               GraphDef* optimized_graph) {
1939   // After function specialization and inlining graph might be in invalid
1940   // state, and some nodes can read tensors that do not exists anymore in the
1941   // optimized graph: function call node was fully inlined into the graph, or
1942   // output index was invalidated by the output pruning.
1943 
1944   if (!ctx.tensor_mapping().empty()) {
1945     for (NodeDef& node : *optimized_graph->mutable_node()) {
1946       for (int idx = 0; idx < node.input_size(); ++idx) {
1947         TensorId input_tensor = ParseTensorName(node.input(idx));
1948         if (input_tensor.index() == Graph::kControlSlot) break;
1949 
1950         auto mapping = ctx.tensor_mapping().find(input_tensor);
1951         if (mapping != ctx.tensor_mapping().end()) {
1952           node.set_input(idx, mapping->second.ToString());
1953         }
1954       }
1955     }
1956   }
1957 
1958   // Function inlining instantiates function body directly into the optimized
1959   // graph, and we might end up with control dependencies to the nodes that no
1960   // longer exist in a graph. We need to apply control overrides to all
1961   // invalidated nodes, and rewire control dependencies to the control outputs
1962   // node (it's also possible to rewrite singe control edge into multiple edges
1963   // to inlined side-effectful nodes).
1964 
1965   if (!ctx.control_overrides().empty()) {
1966     for (NodeDef& node : *optimized_graph->mutable_node()) {
1967       // Keep track of new control inputs to the node.
1968       absl::flat_hash_set<string> add_ctrl_inputs;
1969 
1970       // Remove all invalidated control inputs.
1971       for (int idx = 0; idx < node.input_size(); /* see below */) {
1972         // TODO(ezhulenev): Use non-allocating TensorId after migrating
1973         // `control_overrides()` to absl::flat_hash_set.
1974         SafeTensorId input_tensor = ParseTensorName(node.input(idx));
1975 
1976         auto overrides = ctx.control_overrides().find(input_tensor.node());
1977         if (overrides != ctx.control_overrides().end()) {
1978           // If this happens it's a bug in the function inlining.
1979           if (input_tensor.index() != Graph::kControlSlot) {
1980             return errors::Internal(
1981                 "Illegal input edge from inlined function call node");
1982           }
1983           // Remove control dependency to the inlined function call node.
1984           node.mutable_input()->SwapElements(idx, node.input_size() - 1);
1985           node.mutable_input()->RemoveLast();
1986 
1987           // Keep track of all overrides.
1988           for (const string& override : overrides->second) {
1989             add_ctrl_inputs.insert(AsControlDependency(override));
1990           }
1991         } else {
1992           // Go to the next input only if the current one was not invalidated,
1993           // otherwise we need to check the swapped input as well.
1994           ++idx;
1995         }
1996       }
1997 
1998       // Add overrides to the node inputs.
1999       for (const string& ctrl_input : add_ctrl_inputs) {
2000         node.add_input(ctrl_input);
2001       }
2002     }
2003   }
2004 
2005   return Status::OK();
2006 }
2007 
2008 }  // namespace
2009 
// Runs a single function optimizer pass over `graph` and appends the result to
// `optimized_graph`.
//
// Nodes of `graph` are visited in topological order. Each node is either
// replaced by the output of one of the enabled stages (symbolic gradient
// inlining, direct/indirect function call inlining, function specialization)
// or copied unchanged. After all nodes are processed, graph invariants are
// restored, the graph version is copied from `graph`, and the function library
// is written to `optimized_graph` (pruned if `enable_trim_function_library` is
// set).
//
//   item:            used to construct the optimizer context; `item.id` is
//                    logged.
//   graph:           input graph for this pass.
//   iteration:       pass number, used only for logging.
//   skip_nodes:      in/out set of node names that must not be optimized.
//                    Nodes in the set are copied as is, and nodes that a stage
//                    declined or failed to optimize are added to it.
//   optimized_graph: output graph.
//   graph_has_unoptimized_function_calls: output flag, set to true if the
//                    resulting graph still has a symbolic gradient (when
//                    gradient inlining is enabled) or a call to a function not
//                    marked as specialized, whose node is not in `skip_nodes`.
//
// Returns an error if topological sorting fails, if a stage fails after it has
// already added nodes to `optimized_graph`, or if restoring graph invariants
// fails.
Status FunctionOptimizer::RunFunctionOptimizerPass(
    const GrapplerItem& item, const GraphDef& graph, const int iteration,
    std::unordered_set<string>* skip_nodes, GraphDef* optimized_graph,
    bool* graph_has_unoptimized_function_calls) const {
  VLOG(3) << absl::Substitute(
      "Run function optimizer pass (iteration = $0): grappler_item_id = $1",
      iteration, item.id);

  FunctionOptimizerContext ctx(item, opt_level_, graph);

  // Snapshot of the optimizer options that gate each stage below.
  bool inline_gradients = options_.enable_symbolic_gradient_inlining;
  bool inline_func = options_.enable_function_inlining;
  bool specialize_func = options_.enable_function_specialization;

  // We will process all the nodes in topological order, to correctly handle
  // inlining of function call chains.
  std::vector<const NodeDef*> topo_ordered_nodes;
  TF_RETURN_IF_ERROR(ComputeTopologicalOrder(graph, &topo_ordered_nodes));

  for (const NodeDef* node : topo_ordered_nodes) {
    // Each node optimization can modify optimized graph only by adding new
    // nodes, we can check node size to make sure that graph was not modified.
    const int num_nodes_before = optimized_graph->node_size();
    const auto is_graph_modified = [&]() {
      int num_nodes = optimized_graph->node_size();
      DCHECK_GE(num_nodes, num_nodes_before) << "Nodes should not be removed";
      return num_nodes > num_nodes_before;
    };

    // Copy node from the `graph` to the `optimized_graph`.
    const auto copy_node = [&]() { *optimized_graph->add_node() = *node; };

    // If we already failed to optimize this node during one of the previous
    // passes, we just give up, and do not try one more time.
    if (skip_nodes->find(node->name()) != skip_nodes->end()) {
      VLOG(3) << "Skip optimization for node: " << node->name();
      copy_node();
      continue;
    }

// Skip errors if optimized graph was not modified before error happened.
// On error with a modified graph the macro returns the error from the
// enclosing function. On error with an unmodified graph it records the node in
// `skip_nodes` and copies the original node instead. Callers `continue` right
// after the macro in both the success and the skipped-error case.
#define TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(...)                     \
  do {                                                             \
    const Status _status = (__VA_ARGS__);                          \
    if (TF_PREDICT_FALSE(!_status.ok() && is_graph_modified()))    \
      return _status;                                              \
    if (TF_PREDICT_FALSE(!_status.ok() && !is_graph_modified())) { \
      VLOG(3) << "Skip error: " << _status.error_message();        \
      skip_nodes->insert(node->name());                            \
      copy_node();                                                 \
    }                                                              \
  } while (0)

    // ---------------------------------------------------------------------- //
    // 1. Inline symbolic gradients into the optimized graph.                 //
    // ---------------------------------------------------------------------- //

    if (IsSymbolicGradient(*node) && inline_gradients) {
      // Inline symbolic gradients only if the corresponding function is not
      // marked as `_noinline`.
      const auto* f_attr = gtl::FindOrNull(node->attr(), "f");
      const string f_name = f_attr != nullptr ? f_attr->func().name() : "";
      const FunctionDef* func = ctx.function_library().Find(f_name);
      if (func && !MarkedNoInline(*func)) {
        TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
            InlineSymbolicGradient(*node, &ctx, optimized_graph));
        continue;
      } else {
        // Function is missing from the library or marked `_noinline`: do not
        // retry in later passes, and fall through to the stages below.
        VLOG(2) << "Skip SymbolicGradient inlining: function=" << f_name;
        skip_nodes->insert(node->name());
      }
    }

    // ---------------------------------------------------------------------- //
    // 2. Inline or specialize function calls.                                //
    // ---------------------------------------------------------------------- //

    // Find if a node is a function call (direct or indirect).
    const FunctionDef* func = FindFunctionCall(ctx, *node);

    if (func != nullptr) {
      const string& func_name = func->signature().name();

      const bool is_direct_func = IsDirectFunctionCall(*func, *node);
      const bool is_indirect_func = IsIndirectFunctionCall(*func, *node);

      // 2a. Inline direct function call if it's inlinable.
      if (inline_func && is_direct_func) {
        Status inlinable = IsInlinableDirectFunctionCall(ctx, *func, *node);
        if (inlinable.ok()) {
          TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
              InlineDirectFunctionCall(*node, *func, ctx, optimized_graph));
          continue;
        } else {
          // Not inlinable: remember the node, but still fall through so that
          // it can be specialized (2c) during this pass.
          VLOG(2) << inlinable.error_message();
          skip_nodes->insert(node->name());
        }
      }

      // 2b. Inline indirect function call if it's inlinable.
      if (inline_func && is_indirect_func) {
        Status inlinable = IsInlinableIndirectFunctionCall(ctx, *func, *node);
        if (inlinable.ok()) {
          TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
              InlineIndirectFunctionCall(*node, *func, &ctx, optimized_graph));
          continue;
        } else {
          // Same fall-through as in 2a.
          VLOG(2) << inlinable.error_message();
          skip_nodes->insert(node->name());
        }
      }

      // 2c. Specialize it to its instantiation context if can't be inlined,
      // and it has something worth specializing.
      bool specialization_worthy = IsParametrized(*func) ||
                                   HasTrulyConstInputs(*node, ctx) ||
                                   HasUnusedOutputs(*node, *func, ctx);

      // Do not specialize if function has custom gradient.
      const string grad_func = ctx.function_library().FindGradient(func_name);

      if (specialize_func && grad_func.empty() && specialization_worthy) {
        // TODO(ezhulenev): Specialize function call if input has a known shape.
        // Specialize function body for its instantiation attributes and inputs.
        TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
            SpecializeFunction(*node, *func, &ctx, optimized_graph));
        continue;
      } else {
        VLOG(2) << "Skip function specialization: " << func->signature().name();
        skip_nodes->insert(node->name());
      }
    }

    // ---------------------------------------------------------------------- //
    // If we reached this point, node was not handled by any of the stages
    // (inline, specialize), simply copy the node to the optimized graph.
    copy_node();

#undef TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED
  }

  // Rewire inputs that were invalidated by the stages above.
  TF_RETURN_IF_ERROR(RestoreGraphInvariants(ctx, optimized_graph));

  // Preserve the graph version.
  *optimized_graph->mutable_versions() = graph.versions();

  // Prune unreachable function from the library.
  if (options_.enable_trim_function_library) {
    *optimized_graph->mutable_library() =
        PruneFunctionLibrary(ctx.function_library(), *optimized_graph);
  } else {
    *optimized_graph->mutable_library() = ctx.function_library().ToProto();
  }

  // Before returning we check if after single optimization pass we have more
  // unoptimized function calls. The scan stops at the first such node.
  *graph_has_unoptimized_function_calls = false;
  for (const NodeDef& node : optimized_graph->node()) {
    // Check if we can inline symbolic gradient.
    if (IsSymbolicGradient(node) && inline_gradients &&
        skip_nodes->count(node.name()) == 0) {
      *graph_has_unoptimized_function_calls = true;
      break;
    }

    // Check if after inlining we have unoptimized function calls.
    const FunctionDef* func = FindFunctionCall(ctx, node);
    if (func != nullptr && !MarkedSpecialized(*func) &&
        skip_nodes->count(node.name()) == 0) {
      *graph_has_unoptimized_function_calls = true;
      break;
    }
  }

  return Status::OK();
}
2186 
Optimize(Cluster *,const GrapplerItem & item,GraphDef * optimized_graph)2187 Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item,
2188                                    GraphDef* optimized_graph) {
2189   // Nothing to do here.
2190   if (item.graph.library().function_size() == 0) {
2191     *optimized_graph = item.graph;
2192     return Status::OK();
2193   }
2194 
2195   // Do not retry failed function inlining or specialization.
2196   std::unordered_set<string> skip_nodes;
2197   bool graph_has_unoptimized_function_calls = false;
2198 
2199   // We'll keep running function optimizer pass until we inlined and optimized
2200   // all function call nodes.
2201   int iteration = 0;
2202   constexpr int kMaxIterations = 50;
2203 
2204   // 1. Run first optimizer pass with GrapplerItem.graph.
2205   TF_RETURN_IF_ERROR(RunFunctionOptimizerPass(
2206       item, item.graph, 0, &skip_nodes, optimized_graph,
2207       &graph_has_unoptimized_function_calls));
2208 
2209   // 2. If after function inlining we have unoptimized function calls, we have
2210   // to run function optimization pass one more time.
2211   while (graph_has_unoptimized_function_calls) {
2212     if (iteration++ > kMaxIterations) {
2213       VLOG(1) << "Break function optimizer loop at iteration #" << iteration;
2214       break;
2215     }
2216 
2217     GraphDef workspace_graph;
2218     workspace_graph.Swap(optimized_graph);
2219 
2220     TF_RETURN_IF_ERROR(RunFunctionOptimizerPass(
2221         item, workspace_graph, iteration, &skip_nodes, optimized_graph,
2222         &graph_has_unoptimized_function_calls));
2223   }
2224 
2225   return Status::OK();
2226 }
2227 
Feedback(Cluster * cluster,const GrapplerItem & item,const GraphDef & optimized_graph,double result)2228 void FunctionOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
2229                                  const GraphDef& optimized_graph,
2230                                  double result) {
2231   // Nothing to do for FunctionOptimizer.
2232 }
2233 
2234 }  // end namespace grappler
2235 }  // end namespace tensorflow
2236