1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/function_optimizer.h"
17 
18 #include <vector>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/str_replace.h"
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/compiler/jit/defs.h"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/common_runtime/device_mgr.h"
29 #include "tensorflow/core/common_runtime/device_set.h"
30 #include "tensorflow/core/common_runtime/function.h"
31 #include "tensorflow/core/common_runtime/graph_constructor.h"
32 #include "tensorflow/core/common_runtime/lower_case_op.h"
33 #include "tensorflow/core/common_runtime/lower_functional_ops.h"
34 #include "tensorflow/core/common_runtime/lower_if_op.h"
35 #include "tensorflow/core/common_runtime/lower_while_op.h"
36 #include "tensorflow/core/common_runtime/placer.h"
37 #include "tensorflow/core/framework/attr_value_util.h"
38 #include "tensorflow/core/framework/function.h"
39 #include "tensorflow/core/framework/function.pb.h"
40 #include "tensorflow/core/framework/graph_def_util.h"
41 #include "tensorflow/core/framework/node_def.pb.h"
42 #include "tensorflow/core/framework/node_def_util.h"
43 #include "tensorflow/core/framework/op_def.pb.h"
44 #include "tensorflow/core/framework/versions.pb.h"
45 #include "tensorflow/core/graph/algorithm.h"
46 #include "tensorflow/core/graph/control_flow.h"
47 #include "tensorflow/core/graph/graph_node_util.h"
48 #include "tensorflow/core/graph/tensor_id.h"
49 #include "tensorflow/core/grappler/graph_view.h"
50 #include "tensorflow/core/grappler/grappler_item.h"
51 #include "tensorflow/core/grappler/op_types.h"
52 #include "tensorflow/core/grappler/utils.h"
53 #include "tensorflow/core/grappler/utils/functions.h"
54 #include "tensorflow/core/lib/gtl/map_util.h"
55 
56 namespace tensorflow {
57 namespace grappler {
58 namespace {
59 
60 constexpr const char* const kFuncAttr = FunctionLibraryDefinition::kFuncAttr;
61 
62 // Do not specialize functions marked with '_nospecialize' attribute.
63 constexpr const char* const kNoSpecializeAttr = "_nospecialize";
64 
65 // Mark functions that were created as a result of function specialization.
66 constexpr const char* const kGrapplerSpecializedFuncAttr =
67     "_GrapplerSpecializedFunc";
68 
69 // There are two ways of calling a Tensorflow function:
70 //
71 // 1. Direct function call: node.op() is the name of the function.
72 //
73 // 2. Indirect function call: the function name is passed through a node
74 //    attribute, and special Tensorflow kernels are responsible for calling the
75 //    function through the FunctionLibraryRuntime. Example: PartitionedCallOp.
76 
77 // Check if func_node.op() matches the name in FunctionDef signature.
IsDirectFunctionCall(const FunctionDef & func,const NodeDef & func_node)78 bool IsDirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
79   return func_node.op() == func.signature().name();
80 }
81 
82 // Check if func_node has function attribute with a function name matching
83 // FunctionDef signature.
IsIndirectFunctionCall(const FunctionDef & func,const NodeDef & func_node)84 bool IsIndirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
85   if (!IsPartitionedCall(func_node) && !IsStatefulPartitionedCall(func_node)) {
86     return false;
87   }
88 
89   auto* func_attr = AttrSlice(func_node).Find(kFuncAttr);
90   return func_attr != nullptr && func_attr->has_func() &&
91          func_attr->func().name() == func.signature().name();
92 }
93 
FunctionInstantiationAttributes(const FunctionDef & func,const NodeDef & func_node)94 AttrSlice FunctionInstantiationAttributes(const FunctionDef& func,
95                                           const NodeDef& func_node) {
96   if (IsDirectFunctionCall(func, func_node)) {
97     return AttrSlice(func_node);
98 
99   } else if (IsIndirectFunctionCall(func, func_node)) {
100     auto* func_attr = AttrSlice(func_node).Find(kFuncAttr);
101     return AttrSlice(&func_attr->func().attr());
102 
103   } else {
104     LOG(WARNING) << "Can't resolve function instantiation attributes: "
105                  << SummarizeNodeDef(func_node);
106     return AttrSlice();
107   }
108 }
109 
110 // This is a fake device that should not be used for any op kernel execution,
111 // the only purpose of this device is to be passed as a part of DeviceSet to the
112 // Placer.
113 class FakeDevice : public Device {
114  public:
FakeDevice(Env * env,const string & device)115   FakeDevice(Env* env, const string& device) : Device(env, attr(device)) {}
FakeDevice(const string & device)116   explicit FakeDevice(const string& device) : FakeDevice(nullptr, device) {}
Sync()117   Status Sync() override { return Status::OK(); }
118 
119  private:
attr(const string & device)120   static DeviceAttributes attr(const string& device) {
121     DeviceNameUtils::ParsedName parsed_name;
122     bool parsed = DeviceNameUtils::ParseFullName(device, &parsed_name);
123     DCHECK(parsed) << "Failed to parse full device name: " << device;
124 
125     DeviceAttributes attr;
126     attr.set_name(device);
127     attr.set_device_type(parsed_name.type);
128     return attr;
129   }
130 };
131 
132 // -------------------------------------------------------------------------- //
133 // Function specialization.
134 //
135 // FunctionDef is somewhat similar to function template in C++, given all the
136 // type parameters (and attribute values) it generates a statically defined
137 // graph from the type parametrized "graph template" (function body).
138 //
139 // Function specialization instantiates a parametrized FunctionDef into a
140 // statically defined graph, and then converts it back to the fully defined
141 // FunctionDef (it doesn't have any unknown type parameters or attribute
142 // values, known as placeholders).
143 //
144 // Given the fully specified graph we can apply all the Grappler optimizers to
145 // it (see details in MetaOptimizer). Also we can push known constant inputs
146 // into the function body, and remove unused outputs/inputs.
147 
MarkedNoSpecialize(const FunctionDef & fdef)148 bool MarkedNoSpecialize(const FunctionDef& fdef) {
149   const auto attr = AttrSlice(&fdef.attr());
150   bool nospecialize = false;
151   return TryGetNodeAttr(attr, kNoSpecializeAttr, &nospecialize) && nospecialize;
152 }
153 
154 // Specialized function instantiation type parameters, body parameters, and
155 // const inputs.
156 struct FunctionSpecializationSignature {
157   // Currently we do not support functions with tensor lists as inputs or
158   // outputs, so caller node input/output ports always match function
159   // input/output arguments.
160   using InputPort = int;
161   using OutputPort = int;
162 
163   string func_name;
164   bool is_in_fetch_set;
165   absl::flat_hash_set<OutputPort> active_outputs;
166   absl::flat_hash_map<string, DataType> type_parameters;
167   absl::flat_hash_map<string, AttrValue> body_parameters;
168   absl::flat_hash_map<InputPort, string> const_inputs;
169 
operator ==tensorflow::grappler::__anondd39c0ba0111::FunctionSpecializationSignature170   bool operator==(const FunctionSpecializationSignature& other) const {
171     bool equals = func_name == other.func_name &&
172                   is_in_fetch_set == other.is_in_fetch_set &&
173                   active_outputs == other.active_outputs &&
174                   type_parameters == other.type_parameters &&
175                   const_inputs == other.const_inputs;
176 
177     if (!equals) return false;
178 
179     // Equality is not defined for AttrValue.
180     if (body_parameters.size() != other.body_parameters.size()) return false;
181 
182     for (const auto& lhs : body_parameters) {
183       auto it = other.body_parameters.find(lhs.first);
184       if (it == other.body_parameters.end()) return false;
185       if (!FastAreAttrValuesEqual(lhs.second, (*it).second)) return false;
186     }
187 
188     return true;
189   }
190 
191   template <typename H>
AbslHashValue(H h,const FunctionSpecializationSignature & s)192   friend H AbslHashValue(H h, const FunctionSpecializationSignature& s) {
193     H base = H::combine(std::move(h), s.func_name, s.is_in_fetch_set);
194 
195     // First pre-compute hashes for all values in collections with
196     // non-deterministic iteration order.
197     std::vector<uint64> hashes;
198     hashes.reserve(s.active_outputs.size()         //
199                    + s.type_parameters.size() * 2  //
200                    + s.body_parameters.size() * 2  //
201                    + s.const_inputs.size() * 2);
202 
203     absl::c_transform(s.active_outputs, std::back_inserter(hashes),
204                       hash<OutputPort>());
205 
206     using TypeParam = std::pair<const string, DataType>;
207     absl::c_for_each(s.type_parameters, [&hashes](const TypeParam& type_param) {
208       AttrValue attr_value;
209       attr_value.set_type(type_param.second);
210       hashes.push_back(Hash64(type_param.first));
211       hashes.push_back(AttrValueHash(attr_value));
212     });
213 
214     using BodyParam = std::pair<const string, AttrValue>;
215     absl::c_for_each(s.body_parameters, [&hashes](const BodyParam& body_param) {
216       hashes.push_back(Hash64(body_param.first));
217       hashes.push_back(FastAttrValueHash(body_param.second));
218     });
219 
220     using ConstInput = std::pair<const InputPort, string>;
221     absl::c_for_each(s.const_inputs, [&hashes](const ConstInput& const_input) {
222       hashes.push_back(hash<InputPort>()(const_input.first));
223       hashes.push_back(Hash64(const_input.second));
224     });
225 
226     // Combine all pre-computed hashes in a deterministic order.
227     absl::c_sort(hashes);
228     return H::combine_contiguous(std::move(base), hashes.data(), hashes.size());
229   }
230 };
231 
232 struct FunctionSpecialization {
233   string specialized_func_name;
234   // True if the function caller node is in GrapplerItem fetch set.
235   bool is_in_fetch_set;
236   // Names of the tensors that were pushed down into the function body.
237   absl::flat_hash_set<string> const_inputs;
238   // Control dependencies of pushed down const inputs have to be attached to
239   // function caller node.
240   absl::flat_hash_set<string> control_deps;
241   // Output tensors (ports) that consumed by other nodes in the graph or in a
242   // GrapplerItem fetch set.
243   absl::flat_hash_set<int> active_outputs;
244   // Mapping from original function output port to the output port of
245   // specialized function. If function specialization changes the number of
246   // function outputs it's required to update all node consumers.
247   std::vector<std::pair<int, int>> output_mapping;
248 };
249 
250 // Function optimizer context initialized once for each optimization pass, and
251 // it uses the latest available graph (for the first iteration it will be the
252 // GrapplerItem.graph, for next iterations it will be the output of previous
253 // function optimizer pass).
254 class FunctionOptimizerContext {
255  public:
FunctionOptimizerContext(const GrapplerItem & item,RewriterConfig::Toggle opt_level,const GraphDef & graph)256   explicit FunctionOptimizerContext(const GrapplerItem& item,
257                                     RewriterConfig::Toggle opt_level,
258                                     const GraphDef& graph)
259       : item_(&item),
260         opt_level_(opt_level),
261         function_library_(OpRegistry::Global(), graph.library()),
262         truly_const_nodes_(InferTrulyConstNodes(item, graph)),
263         graph_view_(&graph) {}
264 
item() const265   const GrapplerItem& item() const { return *item_; }
266 
graph_version() const267   const int graph_version() const { return item_->graph.versions().producer(); }
268 
opt_level() const269   RewriterConfig::Toggle opt_level() const { return opt_level_; }
270 
function_library() const271   const FunctionLibraryDefinition& function_library() const {
272     return function_library_;
273   }
function_library()274   FunctionLibraryDefinition& function_library() { return function_library_; }
275 
276   const absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>&
tensor_mapping() const277   tensor_mapping() const {
278     return tensor_mapping_;
279   }
280 
graph_view() const281   const GraphView& graph_view() const { return graph_view_; }
282 
IsFeedNode(const string & node_name) const283   bool IsFeedNode(const string& node_name) const {
284     return absl::c_any_of(
285         item_->feed, [&](const std::pair<std::string, Tensor>& feed) {
286           return ParseTensorName(feed.first).node() == node_name;
287         });
288   }
289 
IsFetchNode(const string & node_name) const290   bool IsFetchNode(const string& node_name) const {
291     return absl::c_any_of(item_->fetch, [&](const string& fetch) {
292       return ParseTensorName(fetch).node() == node_name;
293     });
294   }
295 
IsTrulyConst(const string & name) const296   bool IsTrulyConst(const string& name) const {
297     return TrulyConstNode(name) != nullptr;
298   }
299 
TrulyConstNode(const string & name) const300   const NodeDef* TrulyConstNode(const string& name) const {
301     return gtl::FindWithDefault(truly_const_nodes_, name, nullptr);
302   }
303 
FindFunctionSpecialization(const FunctionSpecializationSignature & sig) const304   const FunctionSpecialization* FindFunctionSpecialization(
305       const FunctionSpecializationSignature& sig) const {
306     return gtl::FindOrNull(specialized_functions_, sig);
307   }
308 
AddSpecializedFunction(const FunctionSpecializationSignature & sig,const FunctionSpecialization & specialized_func)309   void AddSpecializedFunction(const FunctionSpecializationSignature& sig,
310                               const FunctionSpecialization& specialized_func) {
311     specialized_functions_.emplace(sig, specialized_func);
312   }
313 
AddTensorMapping(const SafeTensorId & from,const SafeTensorId & to)314   void AddTensorMapping(const SafeTensorId& from, const SafeTensorId& to) {
315     DCHECK(from.index() != Graph::kControlSlot)
316         << "Tensor mapping must be from regular tensor";
317     DCHECK(to.index() != Graph::kControlSlot)
318         << "Tensor mapping must be to regular tensor";
319 
320     auto inserted = tensor_mapping_.insert({from, to});
321     DCHECK(inserted.second)
322         << "Failed to insert duplicated tensor mapping: "
323         << "from=" << from.ToString() << " to=" << to.ToString();
324   }
325 
AddTensorMapping(const string & func_node,const FunctionSpecialization & specialized_func)326   void AddTensorMapping(const string& func_node,
327                         const FunctionSpecialization& specialized_func) {
328     for (const auto& pair : specialized_func.output_mapping) {
329       int from_idx = pair.first;
330       int to_idx = pair.second;
331       if (from_idx != to_idx) {
332         SafeTensorId from_tensor(func_node, from_idx);
333         SafeTensorId to_tensor(func_node, to_idx);
334         AddTensorMapping(from_tensor, to_tensor);
335       }
336     }
337   }
338 
339  private:
InferTrulyConstNodes(const GrapplerItem & item,const GraphDef & graph)340   static absl::flat_hash_map<string, const NodeDef*> InferTrulyConstNodes(
341       const GrapplerItem& item, const GraphDef& graph) {
342     absl::flat_hash_set<absl::string_view> feed_nodes;
343     for (const auto& feed : item.feed) {
344       feed_nodes.insert(feed.first);
345     }
346 
347     absl::flat_hash_map<string, const NodeDef*> const_nodes;
348     for (const NodeDef& node : graph.node()) {
349       if (IsConstant(node) && !feed_nodes.contains(node.name())) {
350         const_nodes[node.name()] = &node;
351       }
352     }
353 
354     return const_nodes;
355   }
356 
357   const GrapplerItem* item_;  // must outlive this object
358   RewriterConfig::Toggle opt_level_;
359 
360   // Function library constructed from current graph.
361   FunctionLibraryDefinition function_library_;
362 
363   // Nodes that are Const and not in feed.
364   absl::flat_hash_map<string, const NodeDef*> truly_const_nodes_;
365   // Specialized functions.
366   absl::flat_hash_map<FunctionSpecializationSignature,
367                       const FunctionSpecialization>
368       specialized_functions_;
369 
370   // After function specialization, the optimized graph might be in invalid
371   // state, nodes can read from output index that is no longer valid after
372   // unused outputs pruning.
373   //
374   // Tensor mapping that has to be applied to the graph after all functions
375   // optimizations (invalidated tensor id -> optimized graph tensor id).
376   absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>
377       tensor_mapping_;
378 
379   // Use graph view to find active outputs of the function caller nodes.
380   GraphView graph_view_;
381 
382   TF_DISALLOW_COPY_AND_ASSIGN(FunctionOptimizerContext);
383 };
384 
385 // Returns a pointer to the called function definition iff the given node is
386 // indeed a function call. Otherwise returns nullptr.
FindFunctionCall(const FunctionOptimizerContext & ctx,const NodeDef & node)387 const FunctionDef* FindFunctionCall(const FunctionOptimizerContext& ctx,
388                                     const NodeDef& node) {
389   // Check if a node does indirect function call via PartitionedCallOp.
390   if (IsPartitionedCall(node) || IsStatefulPartitionedCall(node)) {
391     const AttrValue* func_attr = AttrSlice(node).Find("f");
392     return (func_attr != nullptr && func_attr->has_func())
393                ? ctx.function_library().Find(func_attr->func().name())
394                : nullptr;
395   }
396 
397   // Check if the function op itself is a function name.
398   return ctx.function_library().Find(node.op());
399 }
400 
GetActiveOutputs(const NodeDef & node,const FunctionOptimizerContext & ctx,int size_hint=0)401 absl::flat_hash_set<int> GetActiveOutputs(const NodeDef& node,
402                                           const FunctionOptimizerContext& ctx,
403                                           int size_hint = 0) {
404   absl::flat_hash_set<int> active_outputs;
405   active_outputs.reserve(static_cast<size_t>(size_hint));
406 
407   // 1. Output can be consumed by the other graph node.
408   const auto node_fanout_edges =
409       ctx.graph_view().GetFanoutEdges(node, /*include_controlled_edges=*/false);
410   for (const GraphView::Edge& edge : node_fanout_edges) {
411     active_outputs.insert(edge.src.port_id);
412   }
413 
414   // 2. Or it can be in a fetch set.
415   for (const string& fetch : ctx.item().fetch) {
416     TensorId fetch_tensor = ParseTensorName(fetch);
417     if (fetch_tensor.node() == node.name()) {
418       active_outputs.insert(fetch_tensor.index());
419     }
420   }
421 
422   return active_outputs;
423 }
424 
HasTrulyConstInputs(const NodeDef & node,const FunctionOptimizerContext & ctx)425 bool HasTrulyConstInputs(const NodeDef& node,
426                          const FunctionOptimizerContext& ctx) {
427   const auto is_truly_const = [&ctx](const string& input) {
428     return ctx.IsTrulyConst(NodeName(input));
429   };
430   return absl::c_any_of(node.input(), is_truly_const);
431 }
432 
HasUnusedOutputs(const NodeDef & func_node,const FunctionDef & func,const FunctionOptimizerContext & ctx)433 bool HasUnusedOutputs(const NodeDef& func_node, const FunctionDef& func,
434                       const FunctionOptimizerContext& ctx) {
435   // Functions with tensor list outputs are not supported right now, so the
436   // number of output args is the same as number of possible function caller
437   // node outputs.
438   int num_outputs = func.signature().output_arg_size();
439   const absl::flat_hash_set<int> active_outputs =
440       GetActiveOutputs(func_node, ctx, /*size_hind*/ num_outputs);
441   int active_outputs_size = active_outputs.size();
442   return active_outputs_size != num_outputs;
443 }
444 
445 // Return pruned FunctionDefLibrary with functions that are reachable from
446 // the optimized graph.
PruneFunctionLibrary(const FunctionLibraryDefinition & flib,const GraphDef & optimized_graph)447 FunctionDefLibrary PruneFunctionLibrary(const FunctionLibraryDefinition& flib,
448                                         const GraphDef& optimized_graph) {
449   FunctionLibraryDefinition pruned_flib =
450       flib.ReachableDefinitions(optimized_graph);
451 
452   int pruned_functions = static_cast<int>(pruned_flib.num_functions()) -
453                          static_cast<int>(flib.num_functions());
454 
455   VLOG(3) << "Pruned function library: " << pruned_flib.num_functions()
456           << " functions (" << pruned_functions << ")";
457 
458   return pruned_flib.ToProto();
459 }
460 
461 // Push all constant inputs of an instantiating node into the function body.
PushDownConstInputs(const NodeDef & func_node,const FunctionOptimizerContext & ctx,GrapplerFunctionItem * item,absl::flat_hash_set<string> * const_inputs,absl::flat_hash_set<string> * control_deps)462 Status PushDownConstInputs(const NodeDef& func_node,
463                            const FunctionOptimizerContext& ctx,
464                            GrapplerFunctionItem* item,
465                            absl::flat_hash_set<string>* const_inputs,
466                            absl::flat_hash_set<string>* control_deps) {
467   // Record node control dependencies in the control_deps set.
468   const auto record_control_deps = [&](const NodeDef* const_input) {
469     for (int i = const_input->input_size() - 1; i >= 0; --i) {
470       const string& input = const_input->input(i);
471       if (IsControlInput(input))
472         control_deps->insert(input);
473       else
474         break;
475     }
476   };
477 
478   for (int i = func_node.input_size() - 1; i >= 0; --i) {
479     const string& input = func_node.input(i);
480     if (IsControlInput(input)) continue;
481 
482     const string node_name = NodeName(input);
483     if (ctx.IsTrulyConst(node_name)) {
484       VLOG(3) << "Push const into function body: input=" << input;
485       const auto* const_input = CHECK_NOTNULL(ctx.TrulyConstNode(node_name));
486       const_inputs->insert(input);
487       record_control_deps(const_input);
488       TF_RETURN_IF_ERROR(ReplaceInputWithConst(*const_input, i, item));
489     }
490   }
491 
492   return Status::OK();
493 }
494 
495 // Remove inputs that were pushed into the function body, and attach their
496 // control dependencies to the function caller node.
RemovePushedDownConstInputs(const FunctionSpecialization & specialization,NodeDef * specialized_func_node)497 void RemovePushedDownConstInputs(const FunctionSpecialization& specialization,
498                                  NodeDef* specialized_func_node) {
499   // Nothing to do if it was no const inputs to the function node.
500   if (specialization.const_inputs.empty()) return;
501 
502   // Keep only non-const inputs.
503   std::vector<string> keep_inputs;
504   const auto& inputs = specialized_func_node->input();
505   absl::c_copy_if(inputs, std::back_inserter(keep_inputs),
506                   [&](const string& input) {
507                     return !specialization.const_inputs.contains(input);
508                   });
509 
510   specialized_func_node->clear_input();
511   for (const auto& keep : keep_inputs) specialized_func_node->add_input(keep);
512 
513   // Attach control dependencies of pushed down const input to the caller node.
514   if (!specialization.control_deps.empty()) {
515     absl::flat_hash_set<string> existing_control_deps;
516 
517     for (const string& input : keep_inputs) {
518       existing_control_deps.insert(AsControlDependency(NodeName(input)));
519     }
520 
521     for (const string& ctrl : specialization.control_deps) {
522       if (!existing_control_deps.contains(ctrl)) {
523         VLOG(3) << "Forward control dependency: input=" << ctrl;
524         specialized_func_node->add_input(ctrl);
525       }
526     }
527   }
528 }
529 
530 // Remove Tin type parameters for pushed down const inputs.
RemovePushedDownConstInputTypes(const FunctionSpecialization & specialization,const NodeDef & func_node,NodeDef * specialized_func_node)531 void RemovePushedDownConstInputTypes(
532     const FunctionSpecialization& specialization, const NodeDef& func_node,
533     NodeDef* specialized_func_node) {
534   // Nothing to do if it was no const inputs to the function node.
535   if (specialization.const_inputs.empty()) return;
536 
537   // Make sure that original function caller has Tin attribute.
538   const AttrValue* tin = AttrSlice(func_node).Find("Tin");
539   if (tin == nullptr || !tin->has_list()) return;
540 
541   // Clear input types for the specialized node.
542   auto* attr = specialized_func_node->mutable_attr();
543   (*attr)["Tin"].mutable_list()->clear_type();
544 
545   // Keep types of non-const inputs.
546   for (int i = 0; i < func_node.input_size(); ++i) {
547     const string& input = func_node.input(i);
548     if (IsControlInput(input)) break;
549 
550     if (!specialization.const_inputs.contains(input)) {
551       DataType dt = tin->list().type(i);
552       (*attr)["Tin"].mutable_list()->add_type(dt);
553     }
554   }
555 }
556 
557 // Remove Tout type parameters for pruned function outputs.
RemoveUnusedOutputsTypes(const FunctionSpecialization & specialization,const NodeDef & func_node,NodeDef * specialized_func_node)558 void RemoveUnusedOutputsTypes(const FunctionSpecialization& specialization,
559                               const NodeDef& func_node,
560                               NodeDef* specialized_func_node) {
561   // Make sure that original function caller has Tout attribute.
562   const AttrValue* tout = AttrSlice(func_node).Find("Tout");
563   if (tout == nullptr || !tout->has_list()) return;
564 
565   // Nothing to do if all outputs are active.
566   int specialization_active_outputs_size = specialization.active_outputs.size();
567   if (specialization_active_outputs_size == tout->list().type_size()) return;
568 
569   // Clear input types for the specialized node.
570   auto* attr = specialized_func_node->mutable_attr();
571   (*attr)["Tout"].mutable_list()->clear_type();
572 
573   // Keep output types of active outputs only.
574   for (int i = 0; i < tout->list().type_size(); ++i) {
575     if (specialization.active_outputs.contains(i)) {
576       DataType dt = tout->list().type(i);
577       (*attr)["Tout"].mutable_list()->add_type(dt);
578     }
579   }
580 }
581 
UpdateSpecializedFunctionCallSite(const FunctionDef & func,const NodeDef & func_node,const string & specialized_func_name,NodeDef * specialized_func_node)582 Status UpdateSpecializedFunctionCallSite(const FunctionDef& func,
583                                          const NodeDef& func_node,
584                                          const string& specialized_func_name,
585                                          NodeDef* specialized_func_node) {
586   if (IsDirectFunctionCall(func, func_node)) {
587     specialized_func_node->set_op(specialized_func_name);
588 
589   } else if (IsIndirectFunctionCall(func, func_node)) {
590     auto* attr = specialized_func_node->mutable_attr();
591     (*attr)[kFuncAttr].mutable_func()->set_name(specialized_func_name);
592 
593   } else {
594     return errors::InvalidArgument("Unknown function call site");
595   }
596 
597   return Status::OK();
598 }
599 
600 // Update a graph node created from the original function caller node, to the
601 // function specialization. Function specialization might change the number of
602 // inputs and outputs, so we have to make sure that graph node is updated
603 // accordingly.
UpdateSpecializedFunctionNode(const FunctionDef & func,const NodeDef & func_node,const FunctionSpecialization & specialization,NodeDef * specialized_func_node)604 Status UpdateSpecializedFunctionNode(
605     const FunctionDef& func, const NodeDef& func_node,
606     const FunctionSpecialization& specialization,
607     NodeDef* specialized_func_node) {
608   // Function called indirectly via custom kernel (e.g. PartitionedCallOp).
609   bool is_indirect_call = IsIndirectFunctionCall(func, func_node);
610 
611   // 1. Call the specialized function instead of original one.
612   TF_RETURN_IF_ERROR(UpdateSpecializedFunctionCallSite(
613       func, func_node, specialization.specialized_func_name,
614       specialized_func_node));
615 
616   // 2. Remove inputs corresponding to the pushed down consts.
617   RemovePushedDownConstInputs(specialization, specialized_func_node);
618 
619   // NOTE: PartitionedCallOp has `Tin` and `Tout` attributes for input/output
620   // types, that must be in sync with updated function signature.
621 
622   // 3. Update input types for the indirect function calls.
623   if (is_indirect_call) {
624     RemovePushedDownConstInputTypes(specialization, func_node,
625                                     specialized_func_node);
626   }
627 
628   // 4. Update output types for the indirect function call. It's unsafe to
629   // change the number of outputs for the fetch nodes, so we just skip them.
630   if (is_indirect_call && !specialization.is_in_fetch_set) {
631     RemoveUnusedOutputsTypes(specialization, func_node, specialized_func_node);
632   }
633 
634   // 5. Remove custom gradient annotation.
635   specialized_func_node->mutable_attr()->erase("_gradient_op_type");
636 
637   return Status::OK();
638 }
639 
InitializeFunctionSpecializationSignature(const NodeDef & func_node,const FunctionDef & func,const AttrSlice & func_instantiation_attr,const FunctionOptimizerContext & ctx,FunctionSpecializationSignature * sig)640 Status InitializeFunctionSpecializationSignature(
641     const NodeDef& func_node, const FunctionDef& func,
642     const AttrSlice& func_instantiation_attr,
643     const FunctionOptimizerContext& ctx, FunctionSpecializationSignature* sig) {
644   DCHECK(sig->const_inputs.empty());
645   DCHECK(sig->active_outputs.empty());
646 
647   sig->func_name = func.signature().name();
648   sig->is_in_fetch_set = ctx.IsFetchNode(func_node.name());
649   sig->active_outputs = GetActiveOutputs(func_node, ctx);
650 
651   TF_RETURN_IF_ERROR(InstantiationTypeParameters(func, func_instantiation_attr,
652                                                  &sig->type_parameters));
653   TF_RETURN_IF_ERROR(InstantiationBodyParameters(func, func_instantiation_attr,
654                                                  &sig->body_parameters));
655 
656   for (int i = 0; i < func_node.input_size(); ++i) {
657     const string& input = func_node.input(i);
658     if (IsControlInput(input)) break;
659     if (ctx.IsTrulyConst(input)) {
660       sig->const_inputs.emplace(i, input);
661     }
662   }
663 
664   return Status::OK();
665 }
666 
667 // Create a name for the function specialization. The name of the function, name
668 // of the node instantiating it, and a Grappler item id should generate unique
669 // function name. Meta optimizer might create multiple Grappler items for the
670 // same graph when optimizing functions, but it's guaranteed that they all will
671 // have unique ids.
SpecializedFunctionName(const FunctionOptimizerContext & ctx,const FunctionDef & func,const NodeDef & func_node)672 string SpecializedFunctionName(const FunctionOptimizerContext& ctx,
673                                const FunctionDef& func,
674                                const NodeDef& func_node) {
675   return absl::Substitute(
676       "$0_specialized_for_$1_at_$2", func.signature().name(),
677       absl::StrReplaceAll(func_node.name(), {{"/", "_"}}), ctx.item().id);
678 }
679 
SpecializeFunction(const NodeDef & func_node,const FunctionDef & func,FunctionOptimizerContext * ctx,GraphDef * optimized_graph)680 Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
681                           FunctionOptimizerContext* ctx,
682                           GraphDef* optimized_graph) {
683   VLOG(2) << "Specialize function call: " << SummarizeNodeDef(func_node);
684 
685   const AttrSlice func_instantiation_attr =
686       FunctionInstantiationAttributes(func, func_node);
687 
688   FunctionSpecializationSignature signature;
689   TF_RETURN_IF_ERROR(InitializeFunctionSpecializationSignature(
690       func_node, func, func_instantiation_attr, *ctx, &signature));
691 
692   // Check if function was already specialized for identical context.
693   const FunctionSpecialization* already_specialized =
694       ctx->FindFunctionSpecialization(signature);
695 
696   if (already_specialized) {
697     VLOG(2) << "Function was already specialized in identical context: "
698                "specialized_name="
699             << already_specialized->specialized_func_name;
700 
701     // Add a function call node for the specialized function.
702     NodeDef* specialized_func_node = optimized_graph->add_node();
703     *specialized_func_node = func_node;
704 
705     TF_RETURN_IF_ERROR(UpdateSpecializedFunctionNode(
706         func, func_node, *already_specialized, specialized_func_node));
707 
708     ctx->AddTensorMapping(specialized_func_node->name(), *already_specialized);
709 
710     return Status::OK();
711   }
712 
713   // Add a new specialized function definition to the library.
714   const auto& flib = ctx->function_library();
715 
716   // Make a GrapplerFunctionItem and convert it back to FunctionDef after
717   // pushing all constant inputs into the function body.
718   GrapplerFunctionItem item;
719   TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
720       func, func_instantiation_attr, flib, ctx->graph_version(), &item));
721 
722   // Push const inputs into the function body, and keep track of their control
723   // dependencies.
724   absl::flat_hash_set<string> const_inputs;
725   absl::flat_hash_set<string> control_deps;
726   TF_RETURN_IF_ERROR(PushDownConstInputs(func_node, *ctx, &item, &const_inputs,
727                                          &control_deps));
728 
729   // Remove function outputs that do not have any consumers. We can't safely
730   // update outputs for the fetch nodes, so we just skip them.
731   std::vector<std::pair<int, int>> output_mapping;
732   if (!signature.is_in_fetch_set) {
733     int num_func_outputs = item.output_size();
734 
735     absl::flat_hash_set<int> remove;
736     for (int i = 0; i < num_func_outputs; ++i) {
737       if (!signature.active_outputs.count(i)) remove.insert(i);
738     }
739 
740     TF_RETURN_IF_ERROR(RemoveFunctionOutputs(remove, &item, &output_mapping));
741   }
742 
743   // TODO(ezhulenev): Push down known input shapes.
744   FunctionDef specialized_func;
745   TF_RETURN_IF_ERROR(MakeFunctionDef(item, flib, &specialized_func));
746 
747   // Find a name for specialized function.
748   const string specialized_func_name =
749       SpecializedFunctionName(*ctx, func, func_node);
750   if (flib.Contains(specialized_func_name)) {
751     // NOTE(ezhulenev): This should never happen. If it happens, it's a sign of
752     // a serious internal error, that must be investigated.
753     return errors::Internal("Created duplicate function specialization");
754   }
755 
756   specialized_func.mutable_signature()->set_name(specialized_func_name);
757   auto* specialized_attr = specialized_func.mutable_attr();
758   (*specialized_attr)[kGrapplerSpecializedFuncAttr].set_b(true);
759 
760   // Add specialized function to the library.
761   TF_RETURN_IF_ERROR(ctx->function_library().AddFunctionDef(specialized_func));
762 
763   // Add a function call node for the specialized function.
764   NodeDef* specialized_func_node = optimized_graph->add_node();
765   *specialized_func_node = func_node;
766 
767   FunctionSpecialization func_specialization = {
768       specialized_func_name, signature.is_in_fetch_set, const_inputs,
769       control_deps,          signature.active_outputs,  output_mapping};
770 
771   TF_RETURN_IF_ERROR(UpdateSpecializedFunctionNode(
772       func, func_node, func_specialization, specialized_func_node));
773 
774   ctx->AddSpecializedFunction(signature, func_specialization);
775   ctx->AddTensorMapping(specialized_func_node->name(), func_specialization);
776 
777   return Status::OK();
778 }
779 
780 // -------------------------------------------------------------------------- //
781 // Inline function calls into a graph using function inlining implementation
782 // from common_runtime:
783 //
784 // 1) Convert GraphDef to Graph.
785 // 2) Inline function calls.
786 // 3) Convert Graph back to the GraphDef.
787 
// Local aliases for the attribute names declared by LowerFunctionalOpsPass.
// They are looked up as boolean node attributes by the `...IsOn()` helpers
// defined in this file.
constexpr const char* const kLowerUsingSwitchMergeAttr =
    LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr;
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
    LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;

// Shorthands for the option types nested in InlineFunctionBodyOptions.
using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
using OutputControlSource = InlineFunctionBodyOptions::OutputControlSource;
795 
796 // Checks if boolean attribute is defined and its value is 'true'.
CheckBoolAttr(const Node * n,absl::string_view attr_name)797 bool CheckBoolAttr(const Node* n, absl::string_view attr_name) {
798   bool match;
799   bool found = TryGetNodeAttr(n->attrs(), attr_name, &match);
800   return found && match;
801 }
802 
803 // Checks if string attribute is defined and it's not empty.
CheckStringAttr(const Node * n,absl::string_view attr_name)804 bool CheckStringAttr(const Node* n, absl::string_view attr_name) {
805   const string& value = GetNodeAttrString(n->attrs(), attr_name);
806   return !value.empty();
807 }
808 
LowerUsingSwitchMergeIsOn(const Node * n)809 bool LowerUsingSwitchMergeIsOn(const Node* n) {
810   return CheckBoolAttr(n, kLowerUsingSwitchMergeAttr);
811 }
812 
LowerAsMultiDeviceFunctionIsOn(const Node * n)813 bool LowerAsMultiDeviceFunctionIsOn(const Node* n) {
814   return CheckBoolAttr(n, kLowerAsMultiDeviceFunctionAttr);
815 }
816 
MarkedForXlaCompilation(const NodeDef & n)817 bool MarkedForXlaCompilation(const NodeDef& n) {
818   auto is_enabled = [&](std::string attr_name) -> bool {
819     auto it = n.attr().find(attr_name);
820     return it != n.attr().end() && (!it->second.s().empty() || it->second.b());
821   };
822   return is_enabled("_xla_compile_id") || is_enabled("_tpu_replicate") ||
823          is_enabled(kXlaMustCompileAttr);
824 }
825 
IsExemptFromSideEffectsExecutionValidation(const string & op)826 const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
827   static const auto* exemption = new absl::flat_hash_set<string>(
828       {// LINT.IfChange
829        // Op types that should not run in program order, e.g. because they need
830        // to run asynchronously to avoid deadlock.
831        "CollectiveGather", "CollectiveGatherV2", "CollectiveReduce",
832        "CollectiveReduceV2", "CollectiveBcastSend", "CollectiveBcastRecv",
833        "CollectiveBcastSendV2", "CollectiveBcastRecvV2", "NcclAllReduce",
834        "Send", "Recv",
835 
836        // Legacy random ops.
837        // See details in tensorflow/python/framework/auto_control_deps.py.
838        "RandomUniform", "RandomUniformInt", "RandomStandardNormal",
839        "ParameterizedTruncatedNormal", "TruncatedNormal", "RandomShuffle",
840        "Multinomial", "RandomGamma", "RandomGammaGrad", "RandomPoisson",
841        "RandomPoissonV2",
842 
843        // ReadVariableOp marked as stateful because it consumes DT_RESOURCE,
844        // but it can't generate any observable side-effect.
845        "ReadVariableOp",
846 
847        // CudnnRNN ops are stateful but they can't generate any observable
848        // side-effect.
849        "CudnnRNN", "CudnnRNNBackprop", "CudnnRNNV2", "CudnnRNNV3",
850        "CudnnRNNBackpropV2", "CudnnRNNBackpropV3",
851 
852        // TPUEmbedding EnqueueOps are stateful but this is only between ops with
853        // the same device_ordinal on the same host.
854        "EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch",
855        "EnqueueTPUEmbeddingSparseTensorBatch",
856        "EnqueueTPUEmbeddingRaggedTensorBatch",
857 
858        // SaveV2 and RestoreV2 should be allowed to operate in parallel on
859        // multiple hosts.
860        "SaveV2", "RestoreV2"});
861   // LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
862   return exemption->contains(op);
863 }
864 
865 // Validates that all side effects inside function body will be executed after
866 // function inlining. We do it by looking for a path from stateful ops, to one
867 // of the output control sources.
868 //
869 // When function executed via FunctionLibraryRuntime we do not have to check
870 // this, because `PruneFunctionBody` has special pruning rules for stateful ops.
ValidateSideEffectsExecution(const FunctionBody & fbody,OutputControlSource output_control_source,bool has_outgoing_control_edges,bool validate_outgoing_control_edge=true)871 Status ValidateSideEffectsExecution(
872     const FunctionBody& fbody, OutputControlSource output_control_source,
873     bool has_outgoing_control_edges,
874     bool validate_outgoing_control_edge = true) {
875   // Find all nodes that can produce side effects in the function body graph. We
876   // use 'is_stateful()' bit as an approximation of "has side effects" property.
877   std::vector<const Node*> fbody_side_effects;
878   absl::c_copy_if(
879       fbody.graph->nodes(), std::back_inserter(fbody_side_effects),
880       [](const Node* n) {
881         return n->op_def().is_stateful() && !n->IsArg() && !n->IsRetval() &&
882                !IsExemptFromSideEffectsExecutionValidation(n->type_string());
883       });
884 
885   // When graph executed in TF-2.0 context with automatic control dependencies
886   // tracking, absence of outgoing control edge indicates that no one is
887   // interested in observing side effects, so it is safe to inline the function
888   // body, even if some side-effects will not be executed.
889   if (!fbody_side_effects.empty() && !has_outgoing_control_edges) {
890     const string error_message =
891         "Can't guarantee execution of function side-effects after inlining. "
892         "Function call node has no outgoing control edges.";
893     if (validate_outgoing_control_edge) {
894       return errors::Internal(error_message);
895     } else {
896       VLOG(3) << error_message;
897     }
898   }
899 
900   // Find all nodes in the function body that will be used as control sources.
901   absl::flat_hash_set<const Node*> control_sources;
902   if (output_control_source == OutputControlSource::kDataOutputs) {
903     control_sources = {fbody.ret_nodes.begin(), fbody.ret_nodes.end()};
904   } else if (output_control_source == OutputControlSource::kControlOutputs) {
905     control_sources = {fbody.control_ret_nodes.begin(),
906                        fbody.control_ret_nodes.end()};
907   }
908 
909   for (const Node* side_effect : fbody_side_effects) {
910     VLOG(4) << "Check that node " << side_effect->name()
911             << " will execute after inlining.";
912     bool will_execute = false;
913 
914     const auto is_control_source = [&](const Node* n) -> void {
915       const auto it = control_sources.find(n);
916       if (it != control_sources.end()) {
917         VLOG(4) << "Found a path to control source: " << side_effect->name()
918                 << " ---> " << (*it)->name();
919         will_execute = true;
920       }
921     };
922 
923     DFSFrom(*fbody.graph, {side_effect}, /*enter=*/is_control_source,
924             /*leave=*/{}, NodeComparatorName{});
925 
926     if (!will_execute) {
927       return errors::Internal(
928           "Can't guarantee execution of a side-effectful node, that is not "
929           "reachable from function control source. Function body node: ",
930           SummarizeNode(*side_effect));
931     }
932   }
933 
934   return Status::OK();
935 }
936 
937 // Validates that no dead tensor can reach function output.
ValidateNoDeadOutputs(const FunctionLibraryDefinition & flib_def,const FunctionBody & fbody)938 Status ValidateNoDeadOutputs(const FunctionLibraryDefinition& flib_def,
939                              const FunctionBody& fbody) {
940   absl::flat_hash_set<const Node*> output_nodes = {fbody.ret_nodes.begin(),
941                                                    fbody.ret_nodes.end()};
942 
943   // Find all nodes that can produce dead tensors.
944   std::vector<const Node*> dead_tensor_sources;
945   for (const Node* n : fbody.graph->nodes()) {
946     if (n->IsSwitch()) {
947       VLOG(4) << "Add dead tensors source. Switch node: " << n->name();
948       dead_tensor_sources.push_back(n);
949       continue;
950     }
951 
952     // Native function call can also produce dead tensors if the function body
953     // has mergeless switches.
954     const FunctionDef* fdef = flib_def.Find(n->type_string());
955     if (fdef != nullptr) {
956       std::unique_ptr<FunctionBody> nested_fbody;
957 
958       NameAttrList func;
959       TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(n->def(), &func));
960       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
961                                                  &flib_def, &nested_fbody));
962 
963       if (!ValidateNoDeadOutputs(flib_def, *nested_fbody).ok()) {
964         VLOG(4) << "Add dead tensors source. Function call: " << func.name()
965                 << " node=" << n->name();
966         dead_tensor_sources.push_back(n);
967       }
968     }
969   }
970 
971   for (const Node* dead_tensor_source : dead_tensor_sources) {
972     bool has_dead_output = false;
973 
974     const auto is_output_node = [&](const Node* n) -> void {
975       const auto it = output_nodes.find(n);
976       if (it != output_nodes.end()) {
977         VLOG(4) << "Found a path to output node from dead tensor source: "
978                 << dead_tensor_source->name() << " ---> " << (*it)->name();
979         has_dead_output = true;
980       }
981     };
982 
983     // Stop DFS traversal at a Merge node or if already found a dead output.
984     const auto stop_traversal = [&has_dead_output](const Edge& edge) -> bool {
985       return !edge.src()->IsMerge() || has_dead_output;
986     };
987 
988     DFSFrom(*fbody.graph, {dead_tensor_source}, /*enter=*/is_output_node,
989             /*leave=*/{}, NodeComparatorName{},
990             /*edge_filter=*/stop_traversal);
991 
992     if (has_dead_output) {
993       return errors::Internal(
994           "Can't inline a function with dead outputs. Dead tensor source: ",
995           SummarizeNode(*dead_tensor_source));
996     }
997   }
998 
999   return Status::OK();
1000 }
1001 
1002 // Makes an instance of FunctionBody for inlining from a Node.
MakeFunctionBodyForInlining(const Node & node,const FunctionLibraryDefinition & flib_def,std::unique_ptr<FunctionBody> * fbody)1003 Status MakeFunctionBodyForInlining(const Node& node,
1004                                    const FunctionLibraryDefinition& flib_def,
1005                                    std::unique_ptr<FunctionBody>* fbody) {
1006   VLOG(3) << "Make function body for inlining: " << SummarizeNode(node);
1007 
1008   // Finds a FunctionDef in a library and verifies that it exists.
1009   const auto find_fdef = [&flib_def, &node](
1010                              const string& name,
1011                              const FunctionDef** fdef) -> Status {
1012     if ((*fdef = flib_def.Find(name)) == nullptr) {
1013       return errors::Internal(
1014           "Was not able to find a function definition (name=", name,
1015           ") for a function call: ", SummarizeNode(node));
1016     }
1017     return Status::OK();
1018   };
1019 
1020   // SymbolicGradient is a special "function call" op, which has been
1021   // deprecated for a while, but we still support for compatibility reasons.
1022   if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
1023     NameAttrList func;
1024     TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), kFuncAttr, &func));
1025 
1026     const string grad = flib_def.FindGradient(func.name());
1027 
1028     if (!grad.empty()) {
1029       // Function has a custom gradient registered in a library.
1030       const FunctionDef* grad_fdef;
1031       TF_RETURN_IF_ERROR(find_fdef(grad, &grad_fdef));
1032 
1033       VLOG(4) << "Instantiate a custom SymbolicGradient: gradient=" << grad
1034               << " (function=" << func.name() << ")";
1035       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1036           *grad_fdef, AttrSlice(&func.attr()), &flib_def, fbody));
1037 
1038     } else if (flib_def.Find(func.name()) == nullptr) {
1039       // Function is not really a function, but a primitive op.
1040       gradient::Creator creator;
1041       TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
1042       if (creator == nullptr) {
1043         return errors::InvalidArgument("No gradient is defined for ",
1044                                        func.name());
1045       }
1046       FunctionDef grad_fdef;
1047       TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
1048 
1049       VLOG(4) << "Instantiate a SymbolicGradient for a primitive op: "
1050               << func.name();
1051       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1052           grad_fdef, AttrSlice(&func.attr()), &flib_def, fbody));
1053 
1054     } else {
1055       // Build a gradient graph from the function body.
1056       const FunctionDef* fdef;
1057       TF_RETURN_IF_ERROR(find_fdef(func.name(), &fdef));
1058 
1059       VLOG(4) << "Instantiate a SymbolicGradient for a function: "
1060               << func.name();
1061       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
1062                                                  &flib_def, fbody));
1063       *fbody = SymbolicGradient(**fbody);
1064     }
1065 
1066   } else {
1067     NameAttrList func;
1068     TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node.def(), &func));
1069     const FunctionDef* fdef;
1070     TF_RETURN_IF_ERROR(find_fdef(func.name(), &fdef));
1071 
1072     VLOG(4) << "Instantiate a function call: function=" << func.name();
1073     TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
1074                                                &flib_def, fbody));
1075   }
1076 
1077   return Status::OK();
1078 }
1079 
1080 // Adds a control edges from each data input to the 'caller' to enforce strict
1081 // inputs semantics (all inputs are ready and alive). This is required when:
1082 //
1083 //  1) The function takes resources as inputs, and it doesn't have incoming
1084 //     control edges. In Tensorflow v2 context (eager mode) this should never
1085 //     happen, because automatic control dependencies tracking will add a
1086 //     control edge from the last op touching the resource. However such graphs
1087 //     might be produced by legacy v1 code without automatic dependency
1088 //     tracking. In this case strict function call semantics is required for
1089 //     enforcing side effects execution order.
1090 //
1091 //  2) One of the inputs is consuming Enter[is_constant=true] node, in which
1092 //     case it will be always alive, and potentially can lead to partial
1093 //     function execution after the last loop execution.
1094 //
1095 // Both of these cases would be considered illegal by construction in Tensorflow
1096 // V2, however we have to guarantee that graphs constructed with Tensorflow V1
1097 // will produce correct results.
AddStrictInputSemantics(Node * caller,Graph * g)1098 void AddStrictInputSemantics(Node* caller, Graph* g) {
1099   absl::flat_hash_set<const Node*> existing_control_sources;
1100   for (const Edge* edge : caller->in_edges()) {
1101     if (edge->IsControlEdge()) {
1102       existing_control_sources.insert(edge->src());
1103     }
1104   }
1105 
1106   const bool has_incoming_control_edges = !existing_control_sources.empty();
1107 
1108   const bool has_resource_input =
1109       absl::c_any_of(caller->input_types(),
1110                      [](const DataType dtype) { return dtype == DT_RESOURCE; });
1111 
1112   const bool has_constant_enter_input =
1113       absl::c_any_of(caller->in_edges(), [](const Edge* edge) {
1114         Node* src = edge->src();
1115         return src->IsEnter() && CheckBoolAttr(src, "is_constant");
1116       });
1117 
1118   const bool requires_strict_semantics =
1119       (!has_incoming_control_edges && has_resource_input) ||  // Case #1
1120       (has_constant_enter_input);                             // Case #2
1121   if (!requires_strict_semantics) return;
1122 
1123   std::set<const Node*> data_inputs;
1124   for (const Edge* edge : caller->in_edges()) {
1125     if (!edge->IsControlEdge() &&
1126         !existing_control_sources.contains(edge->src())) {
1127       data_inputs.insert(edge->src());
1128     }
1129   }
1130 
1131   VLOG(3) << "Add control edges from all data inputs to enforce strict "
1132              "semantics with regard to function inputs";
1133 
1134   // Do not add control edges from placeholders, because it will prevent
1135   // pruning, and they can't produce any side effects anyway.
1136   const auto is_placeholder = [](const Node* node) -> bool {
1137     return node->type_string() == "Placeholder";
1138   };
1139 
1140   for (const Node* node : data_inputs) {
1141     if (is_placeholder(node)) continue;
1142     g->AddControlEdge(g->FindNodeId(node->id()), caller,
1143                       /*allow_duplicates=*/true);
1144   }
1145 }
1146 
1147 // Adds a control edge from a frame node if the 'caller' is executing inside a
1148 // While loop (see control_flow.h for the 'frame' node explanation).
AddFrameForwardingControlEdge(const std::vector<ControlFlowInfo> & info,Node * caller,Graph * g)1149 void AddFrameForwardingControlEdge(const std::vector<ControlFlowInfo>& info,
1150                                    Node* caller, Graph* g) {
1151   // All nodes added to the graph by v2 control flow lowering and function
1152   // inlining are guaranteed to have control edges to nested function calls.
1153   int info_size = info.size();
1154   if (caller->id() >= info_size) return;
1155 
1156   // Check if a lowered node is executing inside a while loop.
1157   const Node* frame = info[caller->id()].frame;
1158   const bool is_in_while_loop = frame->id() != Graph::kSourceId;
1159   if (!is_in_while_loop) return;
1160 
1161   // Check if a node already has an incoming control edge. All incoming edges
1162   // must be from the same execution frame (executor.cc invariant), so if we
1163   // already have an incoming control edge, it's guaranteed that it will "carry"
1164   // the same frame as all regular inputs.
1165   const bool has_incoming_control_edges =
1166       absl::c_any_of(caller->in_edges(),
1167                      [](const Edge* edge) { return edge->IsControlEdge(); });
1168   if (has_incoming_control_edges) return;
1169 
1170   VLOG(3) << "Add a frame forwarding control edge: from=" << frame->name()
1171           << " to=" << caller->name();
1172   Node* enter = g->FindNodeId(frame->id());
1173   bool is_constant_enter = enter->attrs().Find("is_constant")->b();
1174   if (is_constant_enter) {
1175     // Enter[is_constant=true] is always alive. So we directly add a control
1176     // edge from that.
1177     g->AddControlEdge(enter, caller);
1178   } else {
1179     // Enter[is_constant=false] activates nodes only in 0th iteration so we
1180     // add an edge from the Merge node which is activated in every iteration.
1181     // A non-constant Enter node must have an edge to a Merge node.
1182     auto it = absl::c_find_if(enter->out_edges(), [](const Edge* e) {
1183       return !e->IsControlEdge() && e->dst()->IsMerge();
1184     });
1185     if (it != enter->out_edges().end()) {
1186       g->AddControlEdge((*it)->dst(), caller);
1187     } else {
1188       LOG(WARNING) << "Enter[is_constant=false] node: " << enter->name()
1189                    << " does not have an outgoing edge to a Merge.";
1190     }
1191   }
1192 }
1193 
1194 // Inlines all function calls that are safe for inlining into the main graph.
1195 // Also lowers control flow V2 ops (functional If/While) into the V1 low level
1196 // ops (Switch/Merge/...).
1197 //
1198 // Runs a placer after inlining, to keep all nodes in a graph placed.
InlineFunctionCalls(const GrapplerItem & item,const RewriterConfig::Toggle opt_level,const bool lower_control_flow,GraphDef * output_graph)1199 Status InlineFunctionCalls(const GrapplerItem& item,
1200                            const RewriterConfig::Toggle opt_level,
1201                            const bool lower_control_flow,
1202                            GraphDef* output_graph) {
1203   bool is_aggressive = opt_level == RewriterConfig::AGGRESSIVE;
1204   VLOG(2) << "Inline function calls: grappler_item_id=" << item.id
1205           << " (aggressive_mode=" << is_aggressive << ")";
1206 
1207   FunctionLibraryDefinition flib_def =
1208       FunctionLibraryDefinition(OpRegistry::Global(), item.graph.library());
1209   std::unique_ptr<Graph> graph = absl::make_unique<Graph>(flib_def);
1210 
1211   GraphConstructorOptions graph_constructor_options;
1212   graph_constructor_options.allow_internal_ops = true;
1213   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(graph_constructor_options,
1214                                             item.graph, graph.get()));
1215 
1216   using NodeNames = absl::flat_hash_set<absl::string_view>;
1217   NodeNames fetch_nodes;
1218   fetch_nodes.reserve(item.fetch.size());
1219   for (const string& fetch : item.fetch) {
1220     fetch_nodes.insert(ParseTensorName(fetch).node());
1221   }
1222   NodeNames keep_nodes(item.keep_ops.begin(), item.keep_ops.end());
1223 
1224   std::vector<string> inlined_function_names;
1225 
1226   // Do not inline function call nodes that are part of a feed set.
1227   NodeNames feed_nodes;
1228   feed_nodes.reserve(item.feed.size());
1229   for (const std::pair<std::string, Tensor>& feed : item.feed) {
1230     feed_nodes.insert(ParseTensorName(feed.first).node());
1231   }
1232 
1233   // If a function call is inside a While loop, it must have an incoming control
1234   // edge, because it will be used to pass execution frame into the function
1235   // body. All nodes without inputs in the function body (e.g. Const and NoOp)
1236   // will be added an extra control edge from the 'input_control_node'.
1237   std::vector<ControlFlowInfo> control_flow_info;
1238   TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &control_flow_info));
1239 
1240   // Function inlining always adds new nodes to the end of the list, so we keep
1241   // iterating until we are out of nodes.
1242   for (int i = 2; i < graph->num_node_ids(); ++i) {
1243     Node* n = graph->FindNodeId(i);
1244     if (n == nullptr) continue;  // deleted node
1245 
1246     // Special case for lowering functional control flow ops. We do not rely on
1247     // LowerFunctionOpsPass because in Grappler we have to be more restrictive
1248     // about what type of function calls we are allowed to inline.
1249     if (lower_control_flow && LowerUsingSwitchMergeIsOn(n)) {
1250       VLOG(2) << "Lower functional control flow op: " << SummarizeNode(*n);
1251       AddStrictInputSemantics(n, graph.get());
1252       AddFrameForwardingControlEdge(control_flow_info, n, graph.get());
1253 
1254       if (n->IsIfNode()) {
1255         TF_RETURN_IF_ERROR(RewriteIfNode(n, graph.get(), false));
1256       } else if (n->IsCaseNode()) {
1257         TF_RETURN_IF_ERROR(RewriteCaseNode(n, graph.get(), false));
1258       } else if (n->IsWhileNode()) {
1259         TF_RETURN_IF_ERROR(RewriteWhileNode(n, graph.get(), false));
1260       }
1261       continue;
1262     }
1263 
1264     // Skip nodes that are not function calls.
1265     if (!IsFunctionCall(flib_def, *n)) continue;
1266     // Skip function calls that we plan to compile later.
1267     if (MarkedForXlaCompilation(n->def())) continue;
1268     // Skip nodes in a feed set.
1269     if (feed_nodes.contains(n->name())) continue;
1270 
1271     // Function body that we will inline into the main graph. It can be a
1272     // function instantiation, or a gradient function instantiated from
1273     // SymbolicGradient op.
1274     std::unique_ptr<FunctionBody> fbody;
1275     TF_RETURN_IF_ERROR(MakeFunctionBodyForInlining(*n, flib_def, &fbody));
1276 
1277     InlineFunctionBodyOptions inline_options;
1278     // Ignore '_noinline' flag in aggressive mode.
1279     inline_options.ignore_noinline = is_aggressive;
1280 
1281     // Function calls created after inlining If/While ops are always inlined as
1282     // multi-device functions and are not required to pass additional Grappler
1283     // validations (side effects execution validation below).
1284     bool force_inline_as_multi_device = LowerAsMultiDeviceFunctionIsOn(n);
1285 
1286     // `PartitionedCall` is a TF-2.0 function call mechanism for multi-device
1287     // functions:
1288     // a) Function can be multi-device.
1289     // b) Automatic control dependencies tracking guarantees that all function
1290     //    side-effectful nodes will have a path to one of the control outputs.
1291     //    Control outputs and control edges between side-effectful (stateful)
1292     //    nodes are used to explicitly mark the nodes that must execute, and to
1293     //    define their execution order.
1294     if (n->IsPartitionedCall() || force_inline_as_multi_device) {
1295       inline_options.output_control_src = OutputControlSource::kControlOutputs;
1296       inline_options.inlined_function_body_placer =
1297           InlinedFunctionBodyPlacer::MultiDevice();
1298     } else {
1299       inline_options.output_control_src = OutputControlSource::kDataOutputs;
1300       inline_options.inlined_function_body_placer =
1301           InlinedFunctionBodyPlacer::SingleDevice();
1302     }
1303 
1304     if (fetch_nodes.contains(n->name())) {
1305       inline_options.keep_caller_node = KeepCallerNode::kFetchable;
1306     } else if (keep_nodes.contains(n->name())) {
1307       inline_options.keep_caller_node = KeepCallerNode::kTargetable;
1308     } else {
1309       inline_options.keep_caller_node = KeepCallerNode::kDoNotKeep;
1310     }
1311 
1312     // Basic validation rules defined in common_runtime shared by all functions.
1313     Status can_inline_function_call =
1314         ValidateInlining(n, fbody.get(), inline_options);
1315 
1316     // Additional validation rules defined only in Grappler.
1317     // TODO(ezhulenev): Move it to common_runtime InlineFunctionBodyOptions?
1318     if (can_inline_function_call.ok()) {
1319       bool has_outgoing_control_edges = absl::c_any_of(
1320           n->out_edges(),
1321           [](const Edge* edge) { return edge->IsControlEdge(); });
1322 
1323       can_inline_function_call = ValidateSideEffectsExecution(
1324           *fbody, inline_options.output_control_src,
1325           has_outgoing_control_edges);
1326 
1327       if (!can_inline_function_call.ok() &&
1328           (is_aggressive || force_inline_as_multi_device)) {
1329         VLOG(2) << "Ignore error: " << can_inline_function_call.error_message();
1330         can_inline_function_call = Status::OK();
1331       }
1332     }
1333     if (can_inline_function_call.ok()) {
1334       can_inline_function_call = ValidateNoDeadOutputs(flib_def, *fbody);
1335     }
1336 
1337     if (can_inline_function_call.ok()) {
1338       VLOG(2) << "Inline function call node: " << n->name();
1339       AddStrictInputSemantics(n, graph.get());
1340       AddFrameForwardingControlEdge(control_flow_info, n, graph.get());
1341 
1342       TF_RETURN_IF_ERROR(InlineFunctionBody(flib_def, graph.get(), n,
1343                                             fbody.get(), inline_options));
1344       inlined_function_names.push_back(fbody->fdef.signature().name());
1345 
1346     } else {
1347       VLOG(2) << "Failed to inline function call node: "
1348               << can_inline_function_call.error_message();
1349     }
1350   }
1351 
1352   VLOG(4) << "Inlined " << inlined_function_names.size()
1353           << " function calls: " << absl::StrJoin(inlined_function_names, ", ");
1354 
1355   // ------------------------------------------------------------------------ //
1356   // Grappler receives the graph after PRE_PLACEMENT, Placer, and POST_PLACEMENT
1357   // passes, so each node has a valid device assignment. After function inlining
1358   // and control flow V2 lowering we have to keep graph placed.
1359 
1360   if (inlined_function_names.empty()) {
1361     VLOG(3) << "Not placing graph after function inlining"
1362             << " (did not inline any of the function calls).";
1363 
1364   } else if (item.devices().empty()) {
1365     // If there are no devices available for placer, we do not place graph after
1366     // function inlining. This happens when Grappler is optimizing the function
1367     // library, or when a graph optimized "offline", without an active runtime
1368     // session, for example as a part of batch job for graph
1369     // analysis/optimization. GrapplerItem instantiated from a function library
1370     // doesn't have to be fully placed after all optimizations; it will be
1371     // placed by the function library runtime before execution.
1372     VLOG(3) << "Not placing graph after function inlining"
1373             << " (device set is empty)";
1374 
1375   } else {
1376     // If we are running in an active runtime session, Grappler will get the
1377     // graph after initial placing is done, and we should have devices for the
1378     // placer.
1379     VLOG(3) << "Run placer for the graph after function inlining. "
1380             << "Devices: [" << absl::StrJoin(item.devices(), ", ") << "]";
1381 
1382     DeviceSet device_set;                               // does not own devices
1383     std::vector<std::unique_ptr<Device>> fake_devices;  // owns fake devices
1384 
1385     for (const string& name : item.devices()) {
1386       auto device = absl::make_unique<FakeDevice>(name);
1387       device_set.AddDevice(device.get());
1388       fake_devices.push_back(std::move(device));
1389     }
1390 
1391     Placer placer(graph.get(), item.id, &device_set);
1392     TF_RETURN_IF_ERROR(placer.Run());
1393   }
1394 
1395   graph->ToGraphDef(output_graph);
1396   return Status::OK();
1397 }
1398 
1399 // Restores tensor mapping after function specialization: all inputs must be
1400 // connected to valid nodes.
RestoreTensorMapping(const FunctionOptimizerContext & ctx,GraphDef * optimized_graph)1401 void RestoreTensorMapping(const FunctionOptimizerContext& ctx,
1402                           GraphDef* optimized_graph) {
1403   if (ctx.tensor_mapping().empty()) return;
1404 
1405   // During function specialization, we might prune unused function outputs. We
1406   // need to "close the holes" that might appear in the function outputs.
1407   //
1408   // Example: prune unused output "f:1"
1409   //
1410   //   f = my_func[T=float](...)          f = my_func_specialized[T=float](...)
1411   //   a = Identity(f:0)             ->   a = Identity(f:0)
1412   //   b = Identity(f:2)                  b = Identity(f:1)
1413   //
1414   // Tensor mapping (size=1): [f:2 -> f:1]
1415   for (NodeDef& node : *optimized_graph->mutable_node()) {
1416     for (int idx = 0; idx < node.input_size(); ++idx) {
1417       TensorId input_tensor = ParseTensorName(node.input(idx));
1418       if (input_tensor.index() == Graph::kControlSlot) break;
1419 
1420       auto mapping = ctx.tensor_mapping().find(input_tensor);
1421       if (mapping != ctx.tensor_mapping().end()) {
1422         node.set_input(idx, mapping->second.ToString());
1423       }
1424     }
1425   }
1426 }
1427 
1428 }  // namespace
1429 
RunFunctionOptimizerPass(const GrapplerItem & item,GraphDef * optimized_graph) const1430 Status FunctionOptimizer::RunFunctionOptimizerPass(
1431     const GrapplerItem& item, GraphDef* optimized_graph) const {
1432   VLOG(3) << "Run function optimizer pass: grappler_item_id=" << item.id;
1433 
1434   // Inline all function calls into a graph using common_runtime/function
1435   // implementation (see `InlineFunctionBody` function documentation).
1436   GraphDef graph_after_inlining;
1437   TF_RETURN_IF_ERROR(InlineFunctionCalls(item, opt_level_, lower_control_flow_,
1438                                          &graph_after_inlining));
1439 
1440   // Specialize function calls that we could not inline.
1441   FunctionOptimizerContext ctx(item, opt_level_, graph_after_inlining);
1442 
1443   for (const NodeDef& node : graph_after_inlining.node()) {
1444     // Function specialization can modify optimized graph only by adding new
1445     // nodes, we can check node size to make sure that graph was not modified.
1446     const int num_nodes_before = optimized_graph->node_size();
1447     const auto is_graph_modified = [&]() {
1448       int num_nodes = optimized_graph->node_size();
1449       DCHECK_GE(num_nodes, num_nodes_before) << "Nodes should not be removed";
1450       return num_nodes > num_nodes_before;
1451     };
1452 
1453     // Copy node from the `graph_after_inlining` to the `optimized_graph`.
1454     const auto copy_node = [&]() { *optimized_graph->add_node() = node; };
1455 
1456     // Find if a node is a function call (direct or indirect).
1457     const FunctionDef* func = FindFunctionCall(ctx, node);
1458     if (func == nullptr) {
1459       copy_node();
1460       continue;
1461     }
1462 
1463     const string& func_name = func->signature().name();
1464 
1465     // Specialize it to its instantiation context if it has something worth
1466     // specializing.
1467     const bool specialization_worthy = IsParametrized(*func) ||
1468                                        HasTrulyConstInputs(node, ctx) ||
1469                                        HasUnusedOutputs(node, *func, ctx);
1470 
1471     // Do not specialize if function has custom gradient or marked nospecialize.
1472     const string grad_func = ctx.function_library().FindGradient(func_name);
1473     const bool no_specialize =
1474         !grad_func.empty() || ctx.IsFeedNode(node.name()) ||
1475         MarkedNoSpecialize(*func) || MarkedForXlaCompilation(node);
1476 
1477     if (specialization_worthy && !no_specialize) {
1478       // TODO(ezhulenev): Specialize function call if input has a known shape.
1479       // Specialize function body for its instantiation attributes and inputs.
1480       Status status = SpecializeFunction(node, *func, &ctx, optimized_graph);
1481       if (!status.ok() && is_graph_modified()) {
1482         return status;
1483       } else if (!status.ok() && !is_graph_modified()) {
1484         VLOG(3) << "Skip specialization error: " << status.error_message();
1485         copy_node();
1486       }
1487       continue;
1488     } else {
1489       VLOG(2) << "Skip function specialization: " << func->signature().name();
1490       copy_node();
1491     }
1492   }
1493 
1494   RestoreTensorMapping(ctx, optimized_graph);
1495 
1496   // Preserve the graph version.
1497   *optimized_graph->mutable_versions() = item.graph.versions();
1498   // Prune unreachable function from the library.
1499   *optimized_graph->mutable_library() =
1500       PruneFunctionLibrary(ctx.function_library(), *optimized_graph);
1501 
1502   return Status::OK();
1503 }
1504 
Optimize(Cluster *,const GrapplerItem & item,GraphDef * optimized_graph)1505 Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item,
1506                                    GraphDef* optimized_graph) {
1507   // Nothing to do here.
1508   if (item.graph.library().function_size() == 0) {
1509     return errors::Aborted("Nothing to do.");
1510   }
1511 
1512   TF_RETURN_IF_ERROR(RunFunctionOptimizerPass(item, optimized_graph));
1513 
1514   return Status::OK();
1515 }
1516 
// Feedback hook of the optimizer. FunctionOptimizer does not use feedback:
// all arguments are ignored and the function has no effect.
void FunctionOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
                                 const GraphDef& optimized_graph,
                                 double result) {
  // Nothing to do for FunctionOptimizer.
}
1522 
1523 }  // end namespace grappler
1524 }  // end namespace tensorflow
1525