1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/inline_function_utils.h"
17 
18 #include <deque>
19 #include <vector>
20 
21 #include "absl/algorithm/container.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/core/common_runtime/device.h"
26 #include "tensorflow/core/common_runtime/function_utils.h"
27 #include "tensorflow/core/common_runtime/graph_constructor.h"
28 #include "tensorflow/core/framework/collective.h"
29 #include "tensorflow/core/framework/function.h"
30 #include "tensorflow/core/framework/node_def.pb.h"
31 #include "tensorflow/core/framework/node_def_util.h"
32 #include "tensorflow/core/framework/op.h"
33 #include "tensorflow/core/framework/op_kernel.h"
34 #include "tensorflow/core/framework/versions.pb.h"
35 #include "tensorflow/core/graph/algorithm.h"
36 #include "tensorflow/core/graph/control_flow.h"
37 #include "tensorflow/core/graph/node_builder.h"
38 #include "tensorflow/core/graph/optimizer_cse.h"
39 #include "tensorflow/core/lib/core/threadpool.h"
40 #include "tensorflow/core/lib/gtl/map_util.h"
41 #include "tensorflow/core/platform/macros.h"
42 #include "tensorflow/core/profiler/lib/traceme.h"
43 #include "tensorflow/core/protobuf/config.pb.h"
44 
45 namespace tensorflow {
46 
// Out-of-line definitions for the constexpr static data members declared in
// LowerFunctionalOpsConstants.
// NOTE(review): presumably required so the members can be ODR-used when the
// file is built as pre-C++17 -- confirm before removing.
/*static*/ constexpr const char* const
    LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
/*static*/ constexpr const char* const
    LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
51 
52 namespace {
// A few string constants used throughout this module. Most of them are local
// aliases for the identically named FunctionLibraryDefinition constants.
// kNodeLabel is the name prefix ("Func/...") given to the nodes created by the
// Add* helpers in this file.
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
static constexpr const char* const kDeviceArgOp =
    FunctionLibraryDefinition::kDeviceArgOp;
static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
static constexpr const char* const kDeviceRetOp =
    FunctionLibraryDefinition::kDeviceRetOp;
static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;
static constexpr const char* const kNodeLabel = "Func";
static constexpr const char* const kFuncAttr =
    FunctionLibraryDefinition::kFuncAttr;
65 
66 // Represents the index-th output of a node.
67 struct Endpoint {
68   Node* node;
69   int index;
70 
71   // Returns the string name represents this endpoint.
nametensorflow::__anon02067df00111::Endpoint72   string name() const {
73     if (index == 0) {
74       return node->name();
75     } else {
76       return strings::StrCat(node->name(), ":", index);
77     }
78   }
79 
dtypetensorflow::__anon02067df00111::Endpoint80   DataType dtype() const { return node->output_type(index); }
81 };
82 
83 struct EndpointHash {
operator ()tensorflow::__anon02067df00111::EndpointHash84   uint64 operator()(const Endpoint& x) const {
85     return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
86                   x.index);
87   }
88 };
89 
90 struct EndpointEq {
operator ()tensorflow::__anon02067df00111::EndpointEq91   bool operator()(const Endpoint& x, const Endpoint& y) const {
92     return (x.node == y.node) && (x.index == y.index);
93   }
94 };
95 
96 // The following Add* routines are used to add a few graph nodes while
97 // functions are transformed.
AddNoOp(StringPiece name,Graph * g)98 static Node* AddNoOp(StringPiece name, Graph* g) {
99   NodeDef ndef;
100   ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
101   ndef.set_op("NoOp");
102   Status s;
103   Node* ret = g->AddNode(ndef, &s);
104   TF_CHECK_OK(s);
105   return ret;
106 }
107 
AddIdentity(StringPiece name,Graph * g,Endpoint input)108 static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) {
109   DCHECK_LT(0, input.dtype());
110   NodeDef ndef;
111   ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
112   ndef.set_op("Identity");
113   ndef.add_input(input.name());
114   AddNodeAttr("T", BaseType(input.dtype()), &ndef);
115   Status s;
116   Node* ret = g->AddNode(ndef, &s);
117   TF_CHECK_OK(s);
118   g->AddEdge(input.node, input.index, ret, 0);
119   return ret;
120 }
121 
InputDevices(const Node & caller)122 std::vector<string> InputDevices(const Node& caller) {
123   std::vector<string> input_devices(caller.in_edges().size());
124   std::vector<string> input_tensors(caller.in_edges().size());
125 
126   for (const Edge* edge : caller.in_edges()) {
127     if (edge->IsControlEdge()) continue;
128     const string& input_device = edge->src()->has_assigned_device_name()
129                                      ? edge->src()->assigned_device_name()
130                                      : edge->src()->requested_device();
131     input_devices[edge->dst_input()] = input_device;
132     input_tensors[edge->dst_input()] =
133         absl::StrCat(edge->src()->name(), ":", edge->src_output());
134   }
135 
136   if (VLOG_IS_ON(4)) {
137     VLOG(4) << "Function instantiation input devices:";
138     for (int i = 0; i < input_devices.size(); ++i) {
139       if (input_tensors[i].empty()) continue;  // skip control edges
140       VLOG(4) << "    [index " << i << "]"
141               << " device: " << input_devices[i]
142               << " (input: " << input_tensors[i] << ")";
143     }
144   }
145 
146   return input_devices;
147 }
148 
149 // Place input nodes on the same device as the corresponding caller input
150 // node. Do not specify any placement for all other nodes.
151 class DefaultFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
152  public:
DefaultFunctionBodyPlacer(const Node & caller)153   explicit DefaultFunctionBodyPlacer(const Node& caller)
154       : input_devices_(InputDevices(caller)) {}
155 
InputNodeDevice(int input_index) const156   absl::optional<string> InputNodeDevice(int input_index) const override {
157     return input_devices_[input_index];
158   }
OutputNodeDevice(int output_index) const159   absl::optional<string> OutputNodeDevice(int output_index) const override {
160     return absl::nullopt;
161   }
ColocateInputOutputIdentities() const162   bool ColocateInputOutputIdentities() const override { return false; }
ControlNodeDevice() const163   absl::optional<string> ControlNodeDevice() const override {
164     return absl::nullopt;
165   }
BodyNodeDevice(const NodeDef & ndef) const166   absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
167     return absl::nullopt;
168   }
169 
170  private:
171   const std::vector<string> input_devices_;
172 };
173 
174 // Place all nodes on the same device as caller node.
175 class SingleDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
176  public:
SingleDeviceFunctionBodyPlacer(const Node & caller)177   explicit SingleDeviceFunctionBodyPlacer(const Node& caller)
178       : caller_device_(caller.def().device()) {}
179 
InputNodeDevice(int input_index) const180   absl::optional<string> InputNodeDevice(int input_index) const override {
181     return caller_device_;
182   }
OutputNodeDevice(int output_index) const183   absl::optional<string> OutputNodeDevice(int output_index) const override {
184     return caller_device_;
185   }
ColocateInputOutputIdentities() const186   bool ColocateInputOutputIdentities() const override { return false; }
ControlNodeDevice() const187   absl::optional<string> ControlNodeDevice() const override {
188     return caller_device_;
189   }
BodyNodeDevice(const NodeDef & ndef) const190   absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
191     return caller_device_;
192   }
193 
194  private:
195   const string caller_device_;
196 };
197 
198 // Place input nodes on the same device as the corresponding caller input
199 // node. Do not place output node. Place control nodes on the same device as
200 // caller node. For all function body nodes overrides job, replica and task
201 // parts of the device assignment to match function caller node.
202 class MultiDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
203  public:
MultiDeviceFunctionBodyPlacer(const Node & caller)204   explicit MultiDeviceFunctionBodyPlacer(const Node& caller)
205       : caller_device_(caller.def().device()),
206         input_devices_(InputDevices(caller)) {
207     has_parsed_caller_device_ =
208         DeviceNameUtils::ParseFullName(caller_device_, &caller_parsed_device_);
209   }
210 
InputNodeDevice(int input_index) const211   absl::optional<string> InputNodeDevice(int input_index) const override {
212     return input_devices_[input_index];
213   }
OutputNodeDevice(int output_index) const214   absl::optional<string> OutputNodeDevice(int output_index) const override {
215     return absl::nullopt;
216   }
ColocateInputOutputIdentities() const217   bool ColocateInputOutputIdentities() const override { return true; }
ControlNodeDevice() const218   absl::optional<string> ControlNodeDevice() const override {
219     return caller_device_;
220   }
BodyNodeDevice(const NodeDef & ndef) const221   absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
222     // TODO(ezhulenev): If function would have been instantiated as a
223     // multi-device function and executed via FunctionLibraryRuntime, it could
224     // be potentially placed on any available device. However there are multiple
225     // tests relying on this assumption. Fix them, and remove this line.
226     if (ndef.device().empty()) return caller_device_;
227 
228     if (!has_parsed_caller_device_) return ndef.device();
229 
230     DeviceNameUtils::ParsedName ndef_parsed_device;
231     if (!DeviceNameUtils::ParseFullName(ndef.device(), &ndef_parsed_device))
232       return ndef.device();
233 
234     // Nodes with explicit device placements in the function body have those
235     // respected, but otherwise the function's placement provides a default.
236     if (caller_parsed_device_.has_job && !ndef_parsed_device.has_job) {
237       ndef_parsed_device.has_job = caller_parsed_device_.has_job;
238       ndef_parsed_device.job = caller_parsed_device_.job;
239     }
240 
241     if (caller_parsed_device_.has_replica && !ndef_parsed_device.has_replica) {
242       ndef_parsed_device.has_replica = caller_parsed_device_.has_replica;
243       ndef_parsed_device.replica = caller_parsed_device_.replica;
244     }
245 
246     if (caller_parsed_device_.has_task && !ndef_parsed_device.has_task) {
247       ndef_parsed_device.has_task = caller_parsed_device_.has_task;
248       ndef_parsed_device.task = caller_parsed_device_.task;
249     }
250     return DeviceNameUtils::ParsedNameToString(ndef_parsed_device);
251   }
252 
253  private:
254   string caller_device_;
255   bool has_parsed_caller_device_;
256   DeviceNameUtils::ParsedName caller_parsed_device_;
257   std::vector<string> input_devices_;
258 };
259 
260 }  // namespace
261 
262 std::unique_ptr<InlinedFunctionBodyPlacer>
DefaultPlacer(const Graph & graph,const Node & caller)263 InlinedFunctionBodyPlacer::DefaultPlacer(const Graph& graph,
264                                          const Node& caller) {
265   VLOG(3) << "Create default placer for inlined function body.";
266   return absl::make_unique<DefaultFunctionBodyPlacer>(caller);
267 }
268 
269 std::unique_ptr<InlinedFunctionBodyPlacer>
SingleDevicePlacer(const Graph & graph,const Node & caller)270 InlinedFunctionBodyPlacer::SingleDevicePlacer(const Graph& graph,
271                                               const Node& caller) {
272   VLOG(3) << "Create single device placer for inlined function body.";
273   return absl::make_unique<SingleDeviceFunctionBodyPlacer>(caller);
274 }
275 
276 std::unique_ptr<InlinedFunctionBodyPlacer>
MultiDevicePlacer(const Graph & graph,const Node & caller)277 InlinedFunctionBodyPlacer::MultiDevicePlacer(const Graph& graph,
278                                              const Node& caller) {
279   VLOG(3) << "Create multi device placer for inlined function body.";
280   return absl::make_unique<MultiDeviceFunctionBodyPlacer>(caller);
281 }
282 
283 namespace {
284 
ValidateNoInline(const FunctionBody * fbody)285 Status ValidateNoInline(const FunctionBody* fbody) {
286   const auto attr = AttrSlice(&fbody->fdef.attr());
287   bool noinline = false;
288   if (TryGetNodeAttr(attr, kNoInlineAttr, &noinline) && noinline) {
289     return errors::InvalidArgument(
290         "Can't inline function marked with '_noinline'");
291   }
292   return Status::OK();
293 }
294 
// Short alias for InlineFunctionBodyOptions::OutputControlSource, used when
// formatting the `output_control_src` option in DebugString().
using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
296 
297 // Propagate the debug info of `nodes` in function `func` to the `target` node.
298 // If the debug info of any node is missing, its node name and function name
299 // is used.
PropagateDebugInfoToNode(const string & func,const std::vector<const Node * > & nodes,NodeDef * target)300 void PropagateDebugInfoToNode(const string& func,
301                               const std::vector<const Node*>& nodes,
302                               NodeDef* target) {
303   if (nodes.empty() || target->has_experimental_debug_info()) {
304     return;
305   }
306   for (const Node* node : nodes) {
307     const auto& node_def = node->def();
308     if (node_def.has_experimental_debug_info()) {
309       target->mutable_experimental_debug_info()->MergeFrom(
310           node_def.experimental_debug_info());
311     } else {
312       target->mutable_experimental_debug_info()->add_original_node_names(
313           node_def.name());
314       target->mutable_experimental_debug_info()->add_original_func_names(func);
315     }
316   }
317 }
318 }  // namespace
319 
DebugString() const320 string InlineFunctionBodyOptions::DebugString() const {
321   const auto true_false = [](bool b) { return b ? "true" : "false"; };
322 
323   const auto keep_caller_node_str = [this]() -> string {
324     switch (keep_caller_node) {
325       case KeepCallerNode::kDoNotKeep:
326         return "DoNotKeep";
327       case KeepCallerNode::kFetchable:
328         return "Fetchable";
329       case KeepCallerNode::kTargetable:
330         return "Targetable";
331     }
332   };
333 
334   return absl::StrCat(
335       "disable_inlining=", true_false(disable_inlining),
336       ", ignore_noinline=", true_false(ignore_noinline),
337       ", inline_impl_selection_group_functions=",
338       true_false(inline_impl_selection_group_functions),
339       ", keep_caller_node=", keep_caller_node_str(), ", output_control_src=",
340       output_control_src == OutputControlSrc::kDataOutputs ? "DataOutputs"
341                                                            : "ControlOutputs",
342       ", inlined_function_body_placer=", inlined_function_body_placer.name,
343       ", uniquify_frame_names=", true_false(uniquify_frame_names));
344 }
345 
ValidateInlining(const Node * node,const FunctionBody * fbody,const InlineFunctionBodyOptions & options)346 Status ValidateInlining(const Node* node, const FunctionBody* fbody,
347                         const InlineFunctionBodyOptions& options) {
348   // TODO(ezhulenev): Currently common_runtime function inlining can't guarantee
349   // that all side-effectful ops will be executed after inlining. See Grappler
350   // function_optimizer for details. Unify all function inlining mechanism.
351   // Do not inline if `!fbody->control_ret_nodes.empty()`.
352 
353   const auto num_node_inputs = static_cast<size_t>(node->num_inputs());
354   const auto num_node_outputs = static_cast<size_t>(node->num_outputs());
355 
356   if (num_node_inputs != fbody->arg_types.size() ||
357       num_node_inputs != fbody->arg_nodes.size()) {
358     return errors::InvalidArgument(
359         "Node inputs do not match function arguments: inputs=", num_node_inputs,
360         " arg_types=", fbody->arg_types.size(),
361         " arg_nodes=", fbody->arg_nodes.size());
362   }
363 
364   if (num_node_outputs != fbody->ret_types.size() ||
365       num_node_outputs != fbody->ret_nodes.size()) {
366     return errors::InvalidArgument(
367         "Node outputs do not match function returns: outputs=",
368         num_node_outputs, " ret_types=", fbody->ret_types.size(),
369         " ret_nodes=", fbody->ret_nodes.size());
370   }
371 
372   for (int i = 0; i < node->num_inputs(); ++i) {
373     if (node->input_type(i) != fbody->arg_types[i]) {
374       return errors::InvalidArgument(
375           "Node input type doesn't match function argument type: ",
376           node->input_type(i), " != ", fbody->arg_types[i], " @ index=", i);
377     }
378   }
379   for (int i = 0; i < node->num_outputs(); ++i) {
380     if (node->output_type(i) != fbody->ret_types[i]) {
381       return errors::InvalidArgument(
382           "Node output type doesn't match function return type: ",
383           node->output_type(i), " != ", fbody->ret_types[i], " @ index=", i);
384     }
385   }
386 
387   if (options.disable_inlining) {
388     return errors::InvalidArgument(
389         "Function inlining explicitly disabled by 'options.disable_inlining'");
390   }
391 
392   if (!options.inline_impl_selection_group_functions) {
393     bool is_impl_selection_group_function =
394         fbody->fdef.attr().find("api_implements") != fbody->fdef.attr().end();
395     if (is_impl_selection_group_function) {
396       return errors::InvalidArgument(
397           "Inlining of implementation selection group function ",
398           fbody->fdef.signature().name(),
399           " is disabled by options.inline_impl_selection_group_functions");
400     }
401   }
402 
403   if (!options.ignore_noinline) {
404     TF_RETURN_IF_ERROR(ValidateNoInline(fbody));
405   }
406 
407   return Status::OK();
408 }
409 
410 // Function inlining must preserve function execution semantics with regards to
411 // side-effects visibility. Tensorflow in Eager mode has an automatic control
412 // dependencies tracking mechanism, which enforces well-defined execution order
413 // of all side-effects. Any other frontend (e.g. Swift) must produce graphs
414 // following the same rules, to ensure that function inlining works correctly.
415 //
416 // IMPORTANT: Currently we do not have a true notion of "side-effectful" node,
417 // we assume that all stateful nodes might have side-effects, though it's not
418 // true in practice, e.g. `ReadVariableOp` doesn't have an observable
419 // side-effect.
420 //
421 // Automatic control dependency rules in Tensorflow 2.0 (python in eager mode):
422 //
423 // 1) When a function has a resource (DT_RESOURCE data type) input argument it
424 //   "captures" the mutable resource.  This is implemented by automatically
//    adding an incoming control edge from the previous side-effectful op
426 //    touching that resource, and an outgoing control edge to the next
427 //    side-effectful op using the same resource. This serializes the mutations
428 //    of the resource to make graph execution deterministic.
429 //
430 // 2) All stateful ops inside a function body are guaranteed to execute in
431 //    program order, this is achieved by adding control edges between stateful
432 //    ops at graph construction time. Stateful ops (or ops that must execute)
433 //    should be in the function control return set. Having a data edge to the
434 //    regular function output might be not enough, because after function
435 //    inlining it might happen that data output is unused.
436 //
437 // 3) Furthermore, all ops accepting the same resource as an input are
438 //    guaranteed to run in program order. This is also done by adding control
439 //    edges at graph construction time. The last op touching the resource
440 //    must be in a control return set, which will guarantee that all side
441 //    effects to the resource will happen before function completion.
442 //
443 // Function inlining must preserve side-effect visibility:
444 //
445 // 1) All side-effects to the captured resources, that happened before function
//    call must be visible to the function body nodes using those resources.
447 //
448 // 2) All side-effects to the captured resources, that happened inside function
449 //    body, must be visible to every op/function using that resource after the
450 //    function call completed.
451 //
452 // To guarantee that these properties are preserved after inlining we:
453 //
454 // 1) Create "input_control_node" NoOp. Function call node incoming control
455 //    edges will be forwarded *to* this node. Function inputs (Identity nodes)
456 //    will have a control edge *from* this node. If function body has nodes
457 //    without inputs, they will have a control edge *from* this node.
458 //
459 // 2) Create "output_control_node" NoOp. All nodes that have incoming control
460 //    edge *from* the function call node, will be forwarded to this node.
461 //
462 //    We have two options for choosing which nodes will have a control edge *to*
463 //    the "output control node":
464 //       a) control returns            (`control_ret` field in FunctionDef)
465 //       b) data returns               (`ret` field in FunctionDef)
466 //
467 //    We do a) for multi-device function calls in Tensorflow v2 and b)
468 //    for the rest for compatibility with Tensorflow v1.
469 //
470 //    Following the automatic control dependencies tracking rules, a node that
471 //    has an incoming control edge from the function call node is dependent on
472 //    the side-effects happening inside the function body. The output control
473 //    node will guarantee side-effects execution order.
474 //
475 //    If function call node doesn't have an outgoing control edge, it means that
476 //    no one is interested in observing side-effects that might have happened.
477 //
478 // Function inlining might leave the graph in partially-placed state. Function
479 // inlining caller must call Placer to guarantee that all nodes are placed.
480 //
481 // Function inlining with `options.override_device=true` will leave graph in
482 // fully placed state, by overriding all inlined nodes devices with the caller
483 // node device, but it will make functions always single-device. These functions
484 // after inlining will not be able to handle resources on multiple devices. This
485 // is currently acceptable for XLA use cases (XLA cluster is always executed on
486 // a single device).
487 //
488 // TODO(ezhulenev): Documentation above is ahead of implementation below.
InlineFunctionBody(const FunctionLibraryDefinition & flib_def,Graph * g,Node * caller,const FunctionBody * fbody,const InlineFunctionBodyOptions & options)489 Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
490                           Node* caller, const FunctionBody* fbody,
491                           const InlineFunctionBodyOptions& options) {
492   VLOG(3) << "Inline function call: " << SummarizeNode(*caller) << " ["
493           << options.DebugString() << "]";
494 
495   Status validation = ValidateInlining(caller, fbody, options);
496   if (!validation.ok()) {
497     return errors::Internal("Inlining mismatch: ", validation.error_message());
498   }
499 
500   // Placer is responsible for assigning devices for all nodes that we will add
501   // to the graph.
502   const std::unique_ptr<InlinedFunctionBodyPlacer> placer =
503       options.inlined_function_body_placer.get(*g, *caller);
504 
505   // We can't possibly introduce a duplicate control edge during function
506   // inlining, so we skip this check in calls to the 'g->AddControlEdge(...)'.
507   static constexpr bool kDoNotCheckDuplicates = true;
508 
509   // ------------------------------------------------------------------------ //
510   // Helper functions to create `NoOp` and `Identity` nodes for auxiliary
511   // control nodes and inlined function inputs and outputs.
512 
513   // Add a NoOp node for function control inputs/outputs.
514   const auto no_op = [&](StringPiece name) -> Node* {
515     Node* node = AddNoOp(absl::StrCat(caller->name(), "/", name), g);
516     const absl::optional<string> device = placer->ControlNodeDevice();
517     if (device.has_value()) node->set_requested_device(*device);
518     return node;
519   };
520 
521   // Add an Identity node for function input.
522   const auto input_identity = [&](StringPiece name, Endpoint input,
523                                   int index) -> Node* {
524     Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
525     const absl::optional<string> device = placer->InputNodeDevice(index);
526     if (device.has_value()) node->set_requested_device(*device);
527     bool colocate_identity = placer->ColocateInputOutputIdentities();
528     if (colocate_identity) {
529       node->AddAttr(kColocationAttrName,
530                     std::vector<string>{absl::StrCat(kColocationGroupPrefix,
531                                                      input.node->name())});
532     }
533     return node;
534   };
535 
536   // Add an Identity node for function output.
537   const auto output_identity = [&](StringPiece name, Endpoint input,
538                                    int index) -> Node* {
539     Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
540     const absl::optional<string> device = placer->OutputNodeDevice(index);
541     if (device.has_value()) node->set_requested_device(*device);
542     bool colocate_identity = placer->ColocateInputOutputIdentities();
543     if (colocate_identity) {
544       node->AddAttr(kColocationAttrName,
545                     std::vector<string>{absl::StrCat(kColocationGroupPrefix,
546                                                      input.node->name())});
547     }
548     return node;
549   };
550 
551   // ------------------------------------------------------------------------ //
552   // Helper function to get an input/output argument name by index. For
553   // functions instantiated from SymbolicGradien corresponding FunctionDef is
554   // empty, and argument name is unknown.
555 
556   auto arg_name = [&](auto& args, size_t i) -> absl::string_view {
557     if (i < args.size()) {
558       return args[i].name();
559     } else {
560       return "<unknown>";
561     }
562   };
563 
564   // ------------------------------------------------------------------------ //
565   // Input edges. For data edges coming into "caller", we first compute the
566   // <src>:<src_output> for the i-th input in "inputs".
567   // If "caller" has any input control dependencies, we add a NoOp
568   // node "input_control_node", which depends on "caller"'s control inputs.
569   std::vector<Endpoint> inputs(caller->num_inputs());
570   Node* input_control_node = nullptr;
571   for (const Edge* e : caller->in_edges()) {
572     if (e->IsControlEdge()) {
573       if (input_control_node == nullptr) {
574         input_control_node = no_op("input_control_node");
575       }
576       g->AddControlEdge(e->src(), input_control_node, kDoNotCheckDuplicates);
577     } else {
578       inputs[e->dst_input()] = {e->src(), e->src_output()};
579     }
580   }
581   if (input_control_node != nullptr) {
582     VLOG(3) << "Created input control node: " << input_control_node->name();
583   }
584 
585   // ------------------------------------------------------------------------ //
586   // Duplicate fbody->graph into 'g'.  First, we copy the nodes of
587   // fbody->graph into 'g' except the source and sink nodes.  We copy
588   // edges among nodes in 'fbody->graph'.
589   //
590   // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
591   // remember 'y' in node_map[x->id()].
592   std::unordered_set<string> fn_nodes;
593   for (Node* n : fbody->graph->op_nodes()) {
594     fn_nodes.insert(n->name());
595   }
596   std::vector<Node*> node_map(fbody->graph->num_node_ids());
597   for (Node* n : fbody->graph->op_nodes()) {
598     NodeDef ndef = n->def();
599 
600     // Maybe override requested node device assignment.
601     const absl::optional<string> device = placer->BodyNodeDevice(ndef);
602     if (device.has_value()) ndef.set_device(*device);
603 
604     // Add inlined function name to inlined node debug information.
605     PropagateDebugInfoToNode(fbody->fdef.signature().name(), {n}, &ndef);
606 
607     // Add the function node name as a prefix:
608     //  1) to node name to avoid collisions
609     //  2) to frame name to avoid multiple LoopCond nodes in one frame
610     //  3) to colocation attribute
611     const string prefix = strings::StrCat(caller->name(), "/");
612     TF_RETURN_IF_ERROR(AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &ndef,
613                                                 options.uniquify_frame_names));
614     TF_RETURN_IF_ERROR(
615         MaybeAddPrefixToColocationConstraints(fn_nodes, prefix, &ndef));
616 
617     Status added_node;
618     Node* clone = g->AddNode(ndef, &added_node);
619     TF_CHECK_OK(added_node);
620     node_map[n->id()] = clone;
621     clone->SetStackTrace(n->GetStackTrace());
622 
623     // If there is an input control node, and one of:
624     // a) the node has no data or control inputs, or
625     // b) the node is a function call (including SymbolicGradient),
626     //    then add a control edge from the input control node to the clone (only
627     //    if it does not already have a control input).
628     //
629     // We must not execute any nodes if the original function call would not
630     // have executed. This is especially critical when the function call is
631     // inside a control-flow construct like tf.cond(). Case (a) ensures that
632     // such nodes do not run.
633     //
634     // The purpose of case (b) is to ensure that instances of case (a) created
635     // by further inlining steps also receive the control dependency.
636     //
637     // This edge is required to transfer execution frame down to all function
638     // body nodes of inlined nested function calls.
639     if (input_control_node) {
640       const auto is_input_edge = [](const Edge* e) -> bool {
641         return !e->src()->IsSource();
642       };
643       const auto is_control_edge = [](const Edge* e) -> bool {
644         return !e->src()->IsSource() && e->IsControlEdge();
645       };
646 
647       // Forward execution frame if:
648       //
649       // a) The node has no data or control inputs.
650       // b) OR the node is a function call without control inputs (control edge
651       //    will be used in nested function inlining to forward execution frame
652       //    to constants inside the function body).
653       //
654       // c) Do not forward control frame to function argument nodes, they will
655       //    be connected to the corresponding function input later.
656       const bool forward_execution_frame =
657           (absl::c_none_of(n->in_edges(), is_input_edge) ||       // (a)
658            (n->IsFunctionCall() &&                                // (b)
659             absl::c_none_of(n->in_edges(), is_control_edge))) &&  //
660           !n->IsArg();                                            // (c)
661 
662       if (forward_execution_frame) {
663         VLOG(4) << "Add control edge from input control node to: "
664                 << clone->name();
665         g->AddControlEdge(input_control_node, clone, kDoNotCheckDuplicates);
666       }
667     }
668   }
669   for (const Edge* e : fbody->graph->edges()) {
670     if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
671         e->dst()->IsSink()) {
672       continue;
673     }
674     Node* src_copy = node_map[e->src()->id()];
675     Node* dst_copy = node_map[e->dst()->id()];
676     g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
677   }
678 
679   // ------------------------------------------------------------------------ //
680   // Connect input edges.
681   //
682   // We create one Identity node for each input. Then, we connect inputs[i] to
683   // the i-th identity node added. The nodes that previously connected
684   // to the j-th output of i-th arg node are reconnected to the i-th
685   // identity node.
686   //
687   // The added identity nodes depend on "input_control_node".
688   VLOG(4) << "Add input Identity nodes for each function argument:";
689 
690   for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
691     Node* arg = node_map[fbody->arg_nodes[i]->id()];
692     Node* n = input_identity("input", inputs[i], i);
693     VLOG(4) << "    [index " << i << "] "
694             << arg_name(fbody->fdef.signature().input_arg(), i) << " as "
695             << n->name() << " (input: " << inputs[i].name()
696             << ", requested_device: " << n->requested_device() << ")";
697 
698     if (input_control_node) {
699       g->AddControlEdge(input_control_node, n, kDoNotCheckDuplicates);
700     }
701     for (const Edge* e : arg->out_edges()) {
702       if (e->IsControlEdge()) {
703         g->AddControlEdge(n, e->dst(), kDoNotCheckDuplicates);
704       } else {
705         g->AddEdge(n, 0, e->dst(), e->dst_input());
706       }
707     }
708     node_map[fbody->arg_nodes[i]->id()] = n;
709     g->RemoveNode(arg);  // 'arg' is disconnected.
710   }
711 
712   // ------------------------------------------------------------------------ //
713   // Connect output edges.
714   //
715   // For i-th return node in fbody->graph, we add in "g" an identity node
716   // (outputs[i-th]). We then reconnect every incoming edge into the i-th return
717   // node to the added identity node.
718   //
719   // For every data edge coming out of "callee"s i-th output, we reconnect it to
720   // the i-th identity added above.
721   //
722   // If "callee" is control-depended upon by any other nodes, we add a NoOp node
723   // "output_control_node". "output_control_node" depends on all identity nodes
724   // added above or on all control return nodes (controlled by
  // `options.output_control_src` value). Nodes that previously depended on
  // "callee" are changed to depend on "output_control_node".
727   //
728   // If `keep_node_fetchable` is `true` we always add an output control node, to
729   // guarantee that executing a fetchable node will execute all side-effects.
730   VLOG(4) << "Add output Identity nodes for each function output argument:";
731 
732   std::vector<Node*> outputs(caller->num_outputs());
733   for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) {
734     Node* ret = node_map[fbody->ret_nodes[i]->id()];
735     Endpoint data;  // Data input for the ret node.
736     for (const Edge* e : ret->in_edges()) {
737       if (!e->IsControlEdge()) {
738         data = {e->src(), e->src_output()};
739         break;
740       }
741     }
742     CHECK(data.node != nullptr);
743     Node* n = output_identity("output", data, i);
744     outputs[i] = n;
745     VLOG(4) << "    [index " << i << "] "
746             << arg_name(fbody->fdef.signature().output_arg(), i) << " as "
747             << n->name() << " (ret: " << data.node->name() << ":" << data.index
748             << ", requested_device: " << n->requested_device() << ")";
749     for (const Edge* e : ret->in_edges()) {
750       if (e->IsControlEdge()) {
751         g->AddControlEdge(e->src(), n, kDoNotCheckDuplicates);
752       }
753     }
754     g->RemoveNode(ret);  // 'ret' is disconnected.
755   }
756 
757   Node* output_control_node = nullptr;
758   const bool has_control_outputs = absl::c_any_of(
759       caller->out_edges(), [](const Edge* e) { return e->IsControlEdge(); });
760 
761   using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
762   const bool keep_caller_node =
763       options.keep_caller_node == KeepCallerNode::kFetchable ||
764       options.keep_caller_node == KeepCallerNode::kTargetable;
765 
766   if (has_control_outputs || keep_caller_node) {
767     output_control_node = no_op("output_control_node");
768     VLOG(4) << "Add output control node: " << output_control_node->name();
769     if (options.output_control_src == OutputControlSrc::kDataOutputs) {
770       for (Node* n : outputs) {
771         VLOG(4) << "    [data output] add control edge from: " << n->name();
772         g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates);
773       }
774     } else {
775       for (Node* fbody_node : fbody->control_ret_nodes) {
776         Node* n = node_map[fbody_node->id()];
777         VLOG(4) << "    [control output] add control edge from: " << n->name();
778         g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates);
779       }
780     }
781   }
782 
  // We can't leave output control node without incoming control edges, because
  // in this case outgoing control edge will lose execution frame information.
785   // We connect input_control_node and output_control_node with a control edge
786   // to forward execution frame to the controlled nodes. Above we add a control
787   // edge to all function calls inside function body, to guarantee that we will
788   // always have input_control_node when we need it.
789   if (output_control_node && output_control_node->in_edges().empty()) {
790     if (input_control_node) {
791       VLOG(4) << "Add a control edge between input and output control nodes: "
792               << input_control_node->name() << " to "
793               << output_control_node->name();
794       g->AddControlEdge(input_control_node, output_control_node,
795                         kDoNotCheckDuplicates);
796     } else {
797       VLOG(4) << "Function inlining potentially dropped execution frame "
798                  "information from outgoing control edges.";
799     }
800   }
801 
802   for (const Edge* e : caller->out_edges()) {
803     if (e->IsControlEdge()) {
804       g->AddControlEdge(output_control_node, e->dst(), kDoNotCheckDuplicates);
805     } else {
806       g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input());
807     }
808   }
809 
810   // ------------------------------------------------------------------------ //
811   // Add an IdentityN or NoOp node in-place of caller node to keep `caller`
812   // fetchable or targetable.
813 
814   if (keep_caller_node) {
815     std::vector<NodeBuilder::NodeOut> output_tensors;
816     absl::c_transform(outputs, std::back_inserter(output_tensors),
817                       [](Node* n) { return NodeBuilder::NodeOut(n, 0); });
818 
819     Node* caller_substitute_node;
820     if (options.keep_caller_node == KeepCallerNode::kTargetable ||
821         output_tensors.empty()) {
822       // IdentityN node must have at least one data input. If function has no
823       // data outputs, we can't keep it fetchable.
824       TF_CHECK_OK(NodeBuilder(caller->name(), "NoOp")
825                       .Device(caller->requested_device())
826                       .ControlInput(output_control_node)
827                       .Finalize(g, &caller_substitute_node));
828 
829     } else if (options.keep_caller_node == KeepCallerNode::kFetchable) {
830       TF_CHECK_OK(NodeBuilder(caller->name(), "IdentityN")
831                       .Device(caller->requested_device())
832                       .Input(output_tensors)
833                       .ControlInput(output_control_node)
834                       .Finalize(g, &caller_substitute_node));
835     }
836   }
837 
838   // ------------------------------------------------------------------------ //
839   // 'caller' is replaced with inlined function body nodes and maybe IdentityN
840   // to keep it fetchable.
841   VLOG(3) << "Successfully inlined function call node: " << caller->name();
842   g->RemoveNode(caller);
843 
844   return Status::OK();
845 }
846 
ExpandInlineFunctions(FunctionLibraryRuntime * lib,Graph * graph,const ExpandInlineFunctionsOptions & options)847 bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
848                            const ExpandInlineFunctionsOptions& options) {
849   std::vector<std::pair<Node*, const FunctionBody*>> candidates;
850 
851   const FunctionLibraryDefinition* fld = lib->GetFunctionLibraryDefinition();
852 
853   for (Node* node : graph->nodes()) {
854     // Skip nodes that are not function calls or SymbolicGradient calls.
855     if (!IsFunctionCall(*lib->GetFunctionLibraryDefinition(), *node)) {
856       continue;
857     }
858     // Skip function calls that marked noinline.
859     bool noinline;
860     if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) {
861       VLOG(3) << "noinline: " << SummarizeNode(*node);
862       continue;
863     }
864     FunctionLibraryRuntime::Handle handle;
865     Status s = InstantiateFunctionCall(node->def(), lib, &handle);
866     if (!s.ok()) {
867       LOG(ERROR) << "Failed to instantiate a function:  " << s.error_message();
868       continue;
869     }
870     const FunctionBody* fbody = lib->GetFunctionBody(handle);
871     CHECK_NOTNULL(fbody);
872     candidates.emplace_back(node, fbody);
873   }
874 
875   bool inlined_any = false;
876   for (const auto& p : candidates) {
877     Status inlined = InlineFunctionBody(*fld, graph, p.first, p.second,
878                                         p.first->IsPartitionedCall()
879                                             ? options.multi_device_options
880                                             : options.native_options);
881     if (inlined.ok()) {
882       inlined_any = true;
883     } else {
884       VLOG(1) << "Failed to inline function call: node=" << p.first->name()
885               << " error=" << inlined.error_message();
886     }
887   }
888 
889   // TODO(ezhulenev): Release handles for inlined function calls.
890 
891   return inlined_any;
892 }
893 
894 }  // end namespace tensorflow
895