1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/function.h"
17 
18 #include <deque>
19 #include <vector>
20 
21 #include "absl/algorithm/container.h"
22 #include "absl/strings/str_cat.h"
23 #include "tensorflow/core/common_runtime/device.h"
24 #include "tensorflow/core/common_runtime/executor.h"
25 #include "tensorflow/core/common_runtime/executor_factory.h"
26 #include "tensorflow/core/common_runtime/graph_optimizer.h"
27 #include "tensorflow/core/common_runtime/memory_types.h"
28 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
29 #include "tensorflow/core/framework/collective.h"
30 #include "tensorflow/core/framework/function.h"
31 #include "tensorflow/core/framework/node_def.pb.h"
32 #include "tensorflow/core/framework/node_def_util.h"
33 #include "tensorflow/core/framework/op.h"
34 #include "tensorflow/core/framework/op_kernel.h"
35 #include "tensorflow/core/framework/versions.pb.h"
36 #include "tensorflow/core/graph/algorithm.h"
37 #include "tensorflow/core/graph/control_flow.h"
38 #include "tensorflow/core/graph/gradients.h"
39 #include "tensorflow/core/graph/graph_constructor.h"
40 #include "tensorflow/core/graph/optimizer_cse.h"
41 #include "tensorflow/core/lib/core/threadpool.h"
42 #include "tensorflow/core/lib/gtl/map_util.h"
43 #include "tensorflow/core/platform/macros.h"
44 
45 // See core/kernels/function_ops.cc for related kernels.
46 
47 namespace tensorflow {
48 
49 // A few string constant used throughout this module.
50 static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
51 static constexpr const char* const kDeviceArgOp =
52     FunctionLibraryDefinition::kDeviceArgOp;
53 static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
54 static constexpr const char* const kDeviceRetOp =
55     FunctionLibraryDefinition::kDeviceRetOp;
56 static constexpr const char* const kGradientOp =
57     FunctionLibraryDefinition::kGradientOp;
58 static constexpr const char* const kNodeLabel = "Func";
59 static constexpr const char* const kFuncAttr =
60     FunctionLibraryDefinition::kFuncAttr;
61 
62 // Represents the index-th output of a node.
63 struct Endpoint {
64   Node* node;
65   int index;
66 
67   // Returns the string name represents this endpoint.
nametensorflow::Endpoint68   string name() const {
69     if (index == 0) {
70       return node->name();
71     } else {
72       return strings::StrCat(node->name(), ":", index);
73     }
74   }
75 
dtypetensorflow::Endpoint76   DataType dtype() const { return node->output_type(index); }
77 };
78 
79 struct EndpointHash {
operator ()tensorflow::EndpointHash80   uint64 operator()(const Endpoint& x) const {
81     return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
82                   x.index);
83   }
84 };
85 
86 struct EndpointEq {
operator ()tensorflow::EndpointEq87   bool operator()(const Endpoint& x, const Endpoint& y) const {
88     return (x.node == y.node) && (x.index == y.index);
89   }
90 };
91 
92 // The following Add* routines are used to add a few graph nodes while
93 // functions are transformed.
AddNoOp(StringPiece name,Graph * g)94 static Node* AddNoOp(StringPiece name, Graph* g) {
95   NodeDef ndef;
96   ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
97   ndef.set_op("NoOp");
98   Status s;
99   Node* ret = g->AddNode(ndef, &s);
100   TF_CHECK_OK(s);
101   return ret;
102 }
103 
AddIdentity(StringPiece name,Graph * g,Endpoint input)104 static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) {
105   DCHECK_LT(0, input.dtype());
106   NodeDef ndef;
107   ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
108   ndef.set_op("Identity");
109   // NOTE(skyewm): we explicitly set the device here to address a multi-GPU
110   // performance issue where this Identity would be placed alone on a GPU,
111   // causing unnecessary device traffic. See b/122483225 for details.
112   ndef.set_device(input.node->def().device());
113   ndef.add_input(input.name());
114   AddNodeAttr("T", BaseType(input.dtype()), &ndef);
115   Status s;
116   Node* ret = g->AddNode(ndef, &s);
117   TF_CHECK_OK(s);
118   g->AddEdge(input.node, input.index, ret, 0);
119   return ret;
120 }
121 
AddArg(Graph * g,DataType dtype,int index)122 static Node* AddArg(Graph* g, DataType dtype, int index) {
123   DCHECK_LT(0, dtype);
124   DCHECK_LT(dtype, DT_FLOAT_REF);
125   NodeDef ndef;
126   ndef.set_name(g->NewName(kNodeLabel));
127   ndef.set_op(kArgOp);
128   AddNodeAttr("T", dtype, &ndef);
129   AddNodeAttr("index", index, &ndef);
130   Status s;
131   Node* ret = g->AddNode(ndef, &s);
132   TF_CHECK_OK(s);
133   return ret;
134 }
135 
AddRet(Graph * g,Endpoint input,int index)136 static Node* AddRet(Graph* g, Endpoint input, int index) {
137   DCHECK_LT(0, input.dtype());
138   DCHECK_LT(input.dtype(), DT_FLOAT_REF);
139   NodeDef ndef;
140   ndef.set_name(g->NewName(kNodeLabel));
141   ndef.set_op(kRetOp);
142   ndef.add_input(input.name());
143   AddNodeAttr("T", input.dtype(), &ndef);
144   AddNodeAttr("index", index, &ndef);
145   Status s;
146   Node* ret = g->AddNode(ndef, &s);
147   TF_CHECK_OK(s);
148   g->AddEdge(input.node, input.index, ret, 0);
149   return ret;
150 }
151 
152 // FunctionLibraryRuntime implementation that forwards all the function calls to
153 // the base runtime implementation, and only overrides overlay lib in calls to
154 // Instantiate (if caller doesn't provide its own overlay lib).
155 //
156 // When function library runtime (FunctionLibraryRuntimeImpl specifically)
157 // instantiates function into a Graph object, it also creates an Executor for
158 // it. That executor has a pointer to the function library runtime instance,
159 // that is used to instantiate all nested function calls.
160 //
161 // If the original function was instantiated using overlay lib, we must preserve
162 // that overlay lib in the executor's function library runtime.
163 //
164 // IMPORTANT: This runtime is intended for use only in executors created for
165 // functions instantiated into a graph in FunctionLibraryRuntimeImpl.
166 class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
167  public:
FunctionLibraryRuntimeOverlay(FunctionLibraryRuntime * base_flr,const FunctionLibraryDefinition * overlay_lib_def)168   FunctionLibraryRuntimeOverlay(
169       FunctionLibraryRuntime* base_flr,
170       const FunctionLibraryDefinition* overlay_lib_def)
171       : base_flr_(base_flr), overlay_lib_def_(overlay_lib_def) {}
172   ~FunctionLibraryRuntimeOverlay() override;
173 
174   Status Instantiate(const string& function_name, AttrSlice attrs,
175                      const InstantiateOptions& options,
176                      Handle* handle) override;
177 
178   Status ReleaseHandle(Handle handle) override;
179 
180   const FunctionBody* GetFunctionBody(Handle h) override;
181 
182   void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
183            std::vector<Tensor>* rets, DoneCallback done) override;
184 
185   void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame,
186            DoneCallback done) override;
187 
188   Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;
189 
190   bool IsStateful(const string& function_name) override;
191 
192   const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
193       const override;
194 
195   Env* env() override;
196   Device* device() override;
197   std::function<void(std::function<void()>)>* runner() override;
198   const DeviceMgr* device_mgr() const override;
199 
200   string DebugString(Handle handle) override;
201   int graph_def_version() override;
202 
203   Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
204                std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
205                FunctionLibraryRuntime** out_flr) override;
206 
207  private:
208   FunctionLibraryRuntime* base_flr_;                  // not owned
209   const FunctionLibraryDefinition* overlay_lib_def_;  // not owned
210 };
211 
212 FunctionLibraryRuntimeOverlay::~FunctionLibraryRuntimeOverlay() = default;
213 
Instantiate(const string & function_name,AttrSlice attrs,const InstantiateOptions & options,Handle * handle)214 Status FunctionLibraryRuntimeOverlay::Instantiate(
215     const string& function_name, AttrSlice attrs,
216     const InstantiateOptions& options, Handle* handle) {
217   // We automatically add overlay lib to all instantiations, if the caller
218   // doesn't provide its own override.
219   if (!options.overlay_lib && overlay_lib_def_) {
220     InstantiateOptions options_copy = options;
221     options_copy.overlay_lib = overlay_lib_def_;
222     return base_flr_->Instantiate(function_name, attrs, options_copy, handle);
223   } else {
224     return base_flr_->Instantiate(function_name, attrs, options, handle);
225   }
226 }
227 
ReleaseHandle(Handle handle)228 Status FunctionLibraryRuntimeOverlay::ReleaseHandle(Handle handle) {
229   return base_flr_->ReleaseHandle(handle);
230 }
231 
GetFunctionBody(Handle h)232 const FunctionBody* FunctionLibraryRuntimeOverlay::GetFunctionBody(Handle h) {
233   return base_flr_->GetFunctionBody(h);
234 }
235 
Run(const Options & opts,Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,DoneCallback done)236 void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
237                                         gtl::ArraySlice<Tensor> args,
238                                         std::vector<Tensor>* rets,
239                                         DoneCallback done) {
240   base_flr_->Run(opts, handle, args, rets, std::move(done));
241 }
242 
Run(const Options & opts,Handle handle,CallFrameInterface * call_frame,DoneCallback done)243 void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
244                                         CallFrameInterface* call_frame,
245                                         DoneCallback done) {
246   base_flr_->Run(opts, handle, call_frame, std::move(done));
247 }
248 
CreateKernel(const NodeDef &,OpKernel **)249 Status FunctionLibraryRuntimeOverlay::CreateKernel(const NodeDef&, OpKernel**) {
250   // We don't have access base_lib_def_ in base function library runtime (aka
251   // FunctionLibraryRuntimeImpl), so to make sure we do not create kernel with
252   // wrong lib_def we just disable creation of new kernels through overlays.
253   //
254   // When we call Instantiate from the base runtime with overlay lib override,
255   // the base runtime implementation is responsible for correctly passing custom
256   // overlay lib to all kernel constructions.
257   return errors::Internal(
258       "Overlay function library runtime doesn't support kernel creation.");
259 }
260 
IsStateful(const string & function_name)261 bool FunctionLibraryRuntimeOverlay::IsStateful(const string& function_name) {
262   // Important: we do not forward lookup to the base FLR.
263   const OpDef* op_def;
264   const Status s = overlay_lib_def_->LookUpOpDef(function_name, &op_def);
265   return s.ok() && op_def->is_stateful();
266 }
267 
env()268 Env* FunctionLibraryRuntimeOverlay::env() { return base_flr_->env(); }
269 
device()270 Device* FunctionLibraryRuntimeOverlay::device() { return base_flr_->device(); }
271 
272 std::function<void(std::function<void()>)>*
runner()273 FunctionLibraryRuntimeOverlay::runner() {
274   return base_flr_->runner();
275 }
276 
device_mgr() const277 const DeviceMgr* FunctionLibraryRuntimeOverlay::device_mgr() const {
278   return base_flr_->device_mgr();
279 }
280 
281 const FunctionLibraryDefinition*
GetFunctionLibraryDefinition() const282 FunctionLibraryRuntimeOverlay::GetFunctionLibraryDefinition() const {
283   return overlay_lib_def_ ? overlay_lib_def_
284                           : base_flr_->GetFunctionLibraryDefinition();
285 }
286 
DebugString(Handle handle)287 string FunctionLibraryRuntimeOverlay::DebugString(Handle handle) {
288   return base_flr_->DebugString(handle);
289 }
290 
graph_def_version()291 int FunctionLibraryRuntimeOverlay::graph_def_version() {
292   return base_flr_->graph_def_version();
293 }
294 
Clone(std::unique_ptr<FunctionLibraryDefinition> * out_lib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * out_pflr,FunctionLibraryRuntime ** out_flr)295 Status FunctionLibraryRuntimeOverlay::Clone(
296     std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
297     std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
298     FunctionLibraryRuntime** out_flr) {
299   // NOTE(ezhulenev): Cloned FunctionLibraryRuntime will be missing overlay lib,
300   // but that's ok because we anyway do not copy/clone instantiated items from
301   // the base FLR.
302   return base_flr_->Clone(out_lib_def, out_pflr, out_flr);
303 }
304 
// FunctionLibraryRuntime that instantiates and runs functions for a single
// device (`device_`, which may be null; `device_name_` then falls back to
// ProcessFunctionLibraryRuntime::kDefaultFLRDevice). Functions instantiated
// for this device are tracked in `items_`; instantiations whose target is not
// local are delegated to `parent_`.
class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
 public:
  FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device,
                             int graph_def_version,
                             const FunctionLibraryDefinition* lib_def,
                             thread::ThreadPool* default_thread_pool,
                             const OptimizerOptions& optimizer_options,
                             CustomKernelCreator custom_kernel_creator,
                             ProcessFunctionLibraryRuntime* parent);

  ~FunctionLibraryRuntimeImpl() override;

  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle handle) override;

  Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;

  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;
  // NOTE(mrry): This overload is currently only implemented for local function
  // execution.
  // TODO(b/70346412): Implement support for remote function execution when
  // passing a call frame.
  void Run(const Options& opts, Handle handle, CallFrameInterface* frame,
           DoneCallback done) override;

  bool IsStateful(const string& function) override;

  // Returns `base_lib_def_`, the library passed to the constructor.
  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override {
    return base_lib_def_;
  }

  Device* device() override { return device_; }

  // Runner that schedules closures on the device's thread pool, or on the
  // default pool given to the constructor; it holds nullptr if neither pool
  // was available.
  std::function<void(std::function<void()>)>* runner() override {
    return &default_runner_;
  }

  const DeviceMgr* device_mgr() const override { return device_mgr_; }
  Env* env() override { return env_; }
  int graph_def_version() override { return graph_def_version_; }

  string DebugString(Handle h) override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr) override;

 private:
  typedef FunctionLibraryRuntimeImpl ME;

  // NOTE: members are initialized in declaration order, and the constructor
  // computes `device_name_` from `device_`. Keep `device_` declared before
  // `device_name_`.
  const DeviceMgr* const device_mgr_;
  Device* const device_;
  Env* const env_;
  const int graph_def_version_;
  const FunctionLibraryDefinition* const base_lib_def_;
  GraphOptimizer optimizer_;
  const CustomKernelCreator custom_kernel_creator_;
  Executor::Args::Runner default_runner_;
  const string device_name_;

  // Looks up op signatures in `base_lib_def_` (installed by the constructor).
  std::function<Status(const string&, const OpDef**)> get_func_sig_;
  // Forwards to CreateKernel(ndef, kernel) (installed by the constructor).
  std::function<Status(const NodeDef&, OpKernel**)> create_kernel_;

  // Guards `next_handle_` and `items_`.
  mutable mutex mu_;

  // LocalHandle to assign to the next newly created Item.
  int next_handle_ GUARDED_BY(mu_);

  // The instantiated and transformed function is encoded as a Graph
  // object, and an executor is created for the graph.
  struct Item {
    // Number of outstanding instantiations sharing this item; incremented on
    // Instantiate and decremented on ReleaseHandle.
    uint64 instantiation_counter = 0;
    const Graph* graph = nullptr;                            // Owned by exec.
    const FunctionLibraryDefinition* overlay_lib = nullptr;  // Not owned.
    FunctionBody* func_graph = nullptr;  // Owned; deleted in ~Item().
    Executor* exec = nullptr;            // Owned; deleted in ~Item().
    // Owned; only created when the item was instantiated with an overlay lib.
    FunctionLibraryRuntimeOverlay* overlay_flr = nullptr;
    string executor_type;

    ~Item() {
      delete this->func_graph;
      delete this->exec;
      delete this->overlay_flr;
    }
  };
  // Items keyed by LocalHandle (values handed out from `next_handle_`),
  // even though the declared key type is Handle.
  std::unordered_map<Handle, std::unique_ptr<Item>> items_ GUARDED_BY(mu_);

  ProcessFunctionLibraryRuntime* parent_ = nullptr;  // not owned.

  Status CreateKernel(const NodeDef& ndef,
                      const FunctionLibraryDefinition* lib_def,
                      OpKernel** kernel);
  Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
                           const FunctionLibraryDefinition* lib_def,
                           FunctionBody** fbody);
  Status CreateItem(Item** item);
  Status GetOrCreateItem(LocalHandle local_handle, Item** item);
  Status InstantiateSymbolicGradient(const NameAttrList& func,
                                     const FunctionLibraryDefinition* lib_def,
                                     FunctionBody** g_body);
  bool IsLocalTarget(const InstantiateOptions& options);
  AttrValueMap FixAttrs(const AttrSlice& attrs);
  void RunRemote(const Options& opts, Handle handle,
                 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                 Item* item, DoneCallback done);

  void ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options& run_opts,
                               CallFrameInterface* frame,
                               Executor::Args* exec_args);

  TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
};
423 
FunctionLibraryRuntimeImpl(const DeviceMgr * dmgr,Env * env,Device * device,int graph_def_version,const FunctionLibraryDefinition * lib_def,thread::ThreadPool * default_thread_pool,const OptimizerOptions & optimizer_options,CustomKernelCreator custom_kernel_creator,ProcessFunctionLibraryRuntime * parent)424 FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
425     const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version,
426     const FunctionLibraryDefinition* lib_def,
427     thread::ThreadPool* default_thread_pool,
428     const OptimizerOptions& optimizer_options,
429     CustomKernelCreator custom_kernel_creator,
430     ProcessFunctionLibraryRuntime* parent)
431     : device_mgr_(dmgr),
432       device_(device),
433       env_(env),
434       graph_def_version_(graph_def_version),
435       base_lib_def_(lib_def),
436       optimizer_(optimizer_options),
437       custom_kernel_creator_(std::move(custom_kernel_creator)),
438       default_runner_(nullptr),
439       device_name_(device_ == nullptr
440                        ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
441                        : device_->name()),
442       next_handle_(0),
443       parent_(parent) {
444   get_func_sig_ = [this](const string& op, const OpDef** sig) {
445     return base_lib_def_->LookUpOpDef(op, sig);
446   };
447   create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) {
448     return CreateKernel(ndef, kernel);
449   };
450   thread::ThreadPool* pool = nullptr;
451   if (device_ != nullptr) {
452     pool = device_->tensorflow_device_thread_pool();
453   }
454   if (pool == nullptr) {
455     pool = default_thread_pool;
456   }
457   if (pool != nullptr) {
458     default_runner_ = [pool](Executor::Args::Closure c) {
459       pool->Schedule(std::move(c));
460     };
461   }
462 }
463 
~FunctionLibraryRuntimeImpl()464 FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {}
465 
466 // An asynchronous op kernel which executes an instantiated function
467 // defined in a library.
468 class CallOp : public AsyncOpKernel {
469  public:
CallOp(FunctionLibraryRuntime::Handle handle,OpKernelConstruction * ctx)470   CallOp(FunctionLibraryRuntime::Handle handle, OpKernelConstruction* ctx)
471       : AsyncOpKernel(ctx), handle_(handle) {}
472 
~CallOp()473   ~CallOp() override {
474     // TODO(iga): Release the cached handle_
475   }
476 
ComputeAsync(OpKernelContext * ctx,DoneCallback done)477   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
478     FunctionLibraryRuntime* lib = ctx->function_library();
479     OP_REQUIRES_ASYNC(ctx, lib != nullptr,
480                       errors::Internal("No function library is provided."),
481                       done);
482     FunctionLibraryRuntime::Options opts;
483     opts.step_id = ctx->step_id();
484     opts.rendezvous = ctx->rendezvous();
485     opts.cancellation_manager = ctx->cancellation_manager();
486     opts.step_container = ctx->step_container();
487     opts.stats_collector = ctx->stats_collector();
488     opts.runner = ctx->runner();
489     opts.collective_executor = ctx->collective_executor();
490     std::vector<Tensor> args;
491     args.reserve(ctx->num_inputs());
492     for (int i = 0; i < ctx->num_inputs(); ++i) {
493       args.push_back(ctx->input(i));
494     }
495     std::vector<Tensor>* rets = new std::vector<Tensor>;
496     lib->Run(opts, handle_, args, rets,
497              [ctx, done, rets](const Status& status) {
498                if (!status.ok()) {
499                  ctx->SetStatus(status);
500                } else {
501                  const int ret_size = static_cast<int>(rets->size());
502                  CHECK_EQ(ret_size, ctx->num_outputs());
503                  for (int i = 0; i < ret_size; ++i) {
504                    ctx->set_output(i, (*rets)[i]);
505                  }
506                }
507                delete rets;
508                done();
509              });
510   }
511 
512  private:
513   FunctionLibraryRuntime::Handle handle_;
514 
515   TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
516 };
517 
GetFunctionBody(Handle h)518 const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
519   LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
520   if (local_handle == kInvalidLocalHandle) {
521     LOG(ERROR) << "Could not find Handle: " << h
522                << " on device: " << device_name_;
523     return nullptr;
524   }
525 
526   tf_shared_lock l(mu_);
527   auto iter = items_.find(local_handle);
528   CHECK(iter != items_.end());
529   return iter->second->func_graph;
530 }
531 
CreateKernel(const NodeDef & ndef,OpKernel ** kernel)532 Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
533                                                 OpKernel** kernel) {
534   return CreateKernel(ndef, base_lib_def_, kernel);
535 }
536 
CreateKernel(const NodeDef & ndef,const FunctionLibraryDefinition * lib_def,OpKernel ** kernel)537 Status FunctionLibraryRuntimeImpl::CreateKernel(
538     const NodeDef& ndef, const FunctionLibraryDefinition* lib_def,
539     OpKernel** kernel) {
540   // If a custom kernel creator is given, try that.
541   Status s;
542   if (custom_kernel_creator_) {
543     std::unique_ptr<OpKernel> ret;
544     s = custom_kernel_creator_(this, ndef, &ret);
545     if (s.ok()) {
546       *kernel = ret.release();
547       return s;
548     } else {
549       VLOG(2) << "Custom creator error: " << s;
550       // Falls through.
551       s = Status::OK();
552     }
553   }
554 
555   if (lib_def->Find(ndef.op()) == nullptr) {
556     // A primitive operation. Creates the registered kernel.
557     return CreateNonCachedKernel(device_, this, ndef, graph_def_version_,
558                                  kernel);
559   }
560 
561   // Try to instantiate this function for the func/attr. Maybe it's
562   // cached already.
563   InstantiateOptions options;
564   if (lib_def != base_lib_def_) {
565     options.overlay_lib = lib_def;
566   }
567   Handle handle;
568   TF_RETURN_IF_ERROR(
569       Instantiate(ndef.op(), AttrSlice(&ndef.attr()), options, &handle));
570 
571   const FunctionBody* fbody = GetFunctionBody(handle);
572   CHECK_NOTNULL(fbody);
573 
574   // TODO(zhifengc): For now, we assume int32 and resources are always on host
575   // memory and other types are always on device memory. We should do type
576   // inference over function body to derive the correct input/output memory
577   // types.
578   MemoryTypeVector input_memory_types;
579   for (const auto& t : fbody->arg_types) {
580     input_memory_types.push_back(MTypeFromDType(t));
581   }
582   MemoryTypeVector output_memory_types;
583   for (const auto& t : fbody->ret_types) {
584     output_memory_types.push_back(MTypeFromDType(t));
585   }
586 
587   // Constructs a CallOp kernel for running the instantiated function.
588   auto device_type = DeviceType(device_->attributes().device_type());
589   OpKernelConstruction construction(
590       device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
591       &fbody->fdef.signature(), this, fbody->arg_types, input_memory_types,
592       fbody->ret_types, output_memory_types, graph_def_version_, &s);
593   if (s.ok()) {
594     *kernel = new CallOp(handle, &construction);
595   }
596   return s;
597 }
598 
FunctionDefToBody(const FunctionDef & fdef,AttrSlice attrs,const FunctionLibraryDefinition * lib_def,FunctionBody ** fbody)599 Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
600     const FunctionDef& fdef, AttrSlice attrs,
601     const FunctionLibraryDefinition* lib_def, FunctionBody** fbody) {
602   if (lib_def == base_lib_def_) {
603     return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig_, fbody);
604   } else {
605     auto get_func_sig = [lib_def](const string& op, const OpDef** sig) {
606       return lib_def->LookUpOpDef(op, sig);
607     };
608     return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
609   }
610 }
611 
InstantiateSymbolicGradient(const NameAttrList & func,const FunctionLibraryDefinition * lib_def,FunctionBody ** g_body)612 Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
613     const NameAttrList& func, const FunctionLibraryDefinition* lib_def,
614     FunctionBody** g_body) {
615   const FunctionDef* fdef = lib_def->Find(func.name());
616   if (fdef == nullptr) {
617     // f is a primitive op.
618     gradient::Creator creator;
619     TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
620     if (creator == nullptr) {
621       return errors::InvalidArgument("No gradient is defined for ",
622                                      func.name());
623     }
624     FunctionDef grad_fdef;
625     // TODO(josh11b): Should filter out the attrs from func that aren't used
626     // by the gradient function.
627     TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
628     TF_RETURN_IF_ERROR(
629         FunctionDefToBody(grad_fdef, AttrSlice(&func.attr()), lib_def, g_body));
630   } else {
631     // f is a user-defined function.
632     InstantiateOptions options;
633     if (lib_def != base_lib_def_) {
634       options.overlay_lib = lib_def;
635     }
636     Handle f_handle;
637     TF_RETURN_IF_ERROR(
638         Instantiate(func.name(), AttrSlice(&func.attr()), options, &f_handle));
639     const FunctionBody* f_body = GetFunctionBody(f_handle);
640     CHECK_NOTNULL(f_body);
641     *g_body = SymbolicGradient(*f_body);
642   }
643   return Status::OK();
644 }
645 
IsLocalTarget(const InstantiateOptions & options)646 bool FunctionLibraryRuntimeImpl::IsLocalTarget(
647     const InstantiateOptions& options) {
648   if (device_ == nullptr) return true;
649   if (options.target.empty()) return true;
650   if (options.is_multi_device_function) return false;
651   Device* target_device;
652   if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) {
653     VLOG(1) << "Not instantiating function in FLR because failed to "
654             << "find device " << options.target << " in device manager";
655     return false;
656   }
657   if (target_device != device_) {
658     VLOG(1) << "Not instantiating function in FLR because target device "
659             << options.target
660             << " is different from FLR's device: " << device_->DebugString();
661     return false;
662   }
663   return true;
664 }
665 
Instantiate(const string & function_name,AttrSlice attrs,const InstantiateOptions & options,Handle * handle)666 Status FunctionLibraryRuntimeImpl::Instantiate(
667     const string& function_name, AttrSlice attrs,
668     const InstantiateOptions& options, Handle* handle) {
669   if (!IsLocalTarget(options)) {
670     return parent_->Instantiate(function_name, attrs, options, handle);
671   }
672 
673   // Since this is a local target, ensure that the local `device_name_` appears
674   // in the canonical key.
675   InstantiateOptions options_copy(options);
676   options_copy.target = device_name_;
677   const string key = Canonicalize(function_name, attrs, options_copy);
678 
679   {
680     mutex_lock l(mu_);
681     *handle = parent_->GetHandle(key);
682     if (*handle != kInvalidHandle) {
683       FunctionLibraryRuntime::LocalHandle handle_on_device =
684           parent_->GetHandleOnDevice(device_name_, *handle);
685       if (handle_on_device == kInvalidLocalHandle) {
686         return errors::Internal("LocalHandle not found for handle ", *handle,
687                                 ".");
688       }
689       auto item_handle = items_.find(handle_on_device);
690       if (item_handle == items_.end()) {
691         return errors::Internal("LocalHandle ", handle_on_device,
692                                 " for handle ", *handle,
693                                 " not found in items.");
694       }
695       ++item_handle->second->instantiation_counter;
696       return Status::OK();
697     }
698   }
699 
700   const FunctionLibraryDefinition* lib_def =
701       options.overlay_lib ? options.overlay_lib : base_lib_def_;
702   FunctionBody* fbody = nullptr;
703   if (function_name == kGradientOp) {
704     const AttrValue* f = attrs.Find(kFuncAttr);
705     if (f == nullptr) {
706       return errors::InvalidArgument("SymbolicGradient is missing attr: f");
707     }
708     const auto& func = f->func();
709     if (func.name() == kGradientOp) {
710       return errors::InvalidArgument("Can't take gradient of SymbolicGradient");
711     }
712     const string grad = lib_def->FindGradient(func.name());
713     if (!grad.empty()) {
714       return Instantiate(grad, AttrSlice(&func.attr()), options, handle);
715     }
716     TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, lib_def, &fbody));
717   } else {
718     const FunctionDef* fdef = lib_def->Find(function_name);
719     if (fdef == nullptr) {
720       return errors::NotFound("Function ", function_name, " is not defined.");
721     }
722     TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody));
723   }
724 
725   LocalHandle local_handle;
726   {
727     mutex_lock l(mu_);
728     *handle = parent_->GetHandle(key);
729     if (*handle != kInvalidHandle) {
730       delete fbody;
731       local_handle = parent_->GetHandleOnDevice(device_name_, *handle);
732       ++items_[local_handle]->instantiation_counter;
733     } else {
734       *handle = parent_->AddHandle(key, device_name_, next_handle_);
735       Item* item = new Item;
736       item->func_graph = fbody;
737       item->overlay_lib = options.overlay_lib;
738       item->instantiation_counter = 1;
739       item->executor_type = ExecutorType(options, attrs);
740       if (options.overlay_lib) {
741         item->overlay_flr =
742             new FunctionLibraryRuntimeOverlay(this, options.overlay_lib);
743       }
744       local_handle = next_handle_++;
745       items_.emplace(local_handle, std::unique_ptr<Item>(item));
746     }
747   }
748 
749   if (options.create_kernels_eagerly) {
750     Item* item;
751     TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, &item));
752   }
753 
754   return Status::OK();
755 }
756 
ReleaseHandle(Handle handle)757 Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
758   LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
759   if (h == kInvalidLocalHandle) {
760     return parent_->ReleaseHandle(handle);
761   }
762 
763   std::unique_ptr<Item> item_to_delete;
764   Status parent_status;
765   {
766     mutex_lock l(mu_);
767     auto it = items_.find(h);
768     if (it == items_.end()) {
769       return errors::Internal(
770           "Inconsistent FunctionLibraryRuntime. Expected to find an item for "
771           "handle ",
772           h, " but found none");
773     }
774     std::unique_ptr<Item>& item = it->second;
775     --item->instantiation_counter;
776     if (item->instantiation_counter == 0) {
777       // We don't simply erase h's item because that would trigger
778       // item destruction while holding mu_. Item destruction can
779       // trigger graph destruction. If the graph contains kernels like
780       // CallOp or PartitionCallOp, their destructors will release cached
781       // function handles, resulting in deadlock here.
782       item_to_delete = std::move(item);
783       items_.erase(h);
784       parent_status = parent_->RemoveHandle(handle);
785     }
786   }
787   return parent_status;
788 }
789 
DumpGraph(StringPiece label,const Graph * g)790 void DumpGraph(StringPiece label, const Graph* g) {
791   // TODO(zhifengc): Change Graph to record #nodes.
792   VLOG(1) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges "
793           << g->num_edges();
794   if (VLOG_IS_ON(2)) {
795     for (const auto& line : str_util::Split(DebugString(g), '\n')) {
796       VLOG(2) << "|| " << line;
797     }
798   }
799 }
800 
OptimizeGraph(FunctionLibraryRuntime * lib,std::unique_ptr<Graph> * g,const GraphOptimizer::Options & graph_optimizer_options)801 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g,
802                    const GraphOptimizer::Options& graph_optimizer_options) {
803   OptimizerOptions opts;
804   opts.set_do_common_subexpression_elimination(true);
805   opts.set_do_function_inlining(true);
806   opts.set_do_constant_folding(true);
807   GraphOptimizer optimizer(opts);
808   optimizer.Optimize(lib, lib->env(), lib->device(), g,
809                      graph_optimizer_options);
810 }
811 
OptimizeGraph(FunctionLibraryRuntime * lib,std::unique_ptr<Graph> * g)812 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g) {
813   OptimizeGraph(lib, g, GraphOptimizer::Options());
814 }
815 
816 namespace {
817 // Removes all stateless nodes that do not contribute to a return
818 // value from the function body. Unlike `RemoveDeadNodes()`, which is
819 // triggered by `OptimizerOptions.do_function_inlining`, this pass
820 // ignores the SINK node, from which (by definition) all nodes are
821 // reverse reachable, and preserves all nodes that are reachable from
822 // control output nodes.
823 //
824 // TODO(ezhulenev, skyewm): Function body should not have special treatment of
825 // stateful ops, graph should encode nodes that must execute with `control_ret`
826 // and `control_output`.
PruneFunctionBody(const FunctionDef & fdef,Graph * g)827 void PruneFunctionBody(const FunctionDef& fdef, Graph* g) {
828   VLOG(2) << "Pruning function body: function_name=" << fdef.signature().name();
829 
830   // `control_ret` nodes must be always executed.
831   std::unordered_set<StringPiece, StringPieceHasher> control_ret_nodes;
832   for (const auto& control_ret : fdef.control_ret()) {
833     control_ret_nodes.insert(control_ret.second);
834   }
835 
836   std::unordered_set<const Node*> nodes;
837   for (auto n : g->nodes()) {
838     // NOTE(mrry): "_Retval" nodes are stateful, and so will be added
839     // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we
840     // specifically exclude them as seeds, to avoid unconditionally executing
841     // unused argument nodes (e.g. in a function like `lambda x, y: y`).
842     // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is
843     // still needed. It would be preferable to prune entire loops and/or
844     // conditionals if they are not used in the graph.
845     if (n->IsControlFlow() ||
846         (n->op_def().is_stateful() && n->type_string() != kArgOp) ||
847         (control_ret_nodes.find(n->name()) != control_ret_nodes.end())) {
848       nodes.insert(n);
849     }
850   }
851   bool changed = PruneForReverseReachability(g, std::move(nodes));
852   if (changed) {
853     FixupSourceAndSinkEdges(g);
854   }
855 }
856 }  // namespace
857 
CreateItem(Item ** item)858 Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
859   const FunctionBody* fbody;
860   const FunctionLibraryDefinition* lib_def;
861   string executor_type;
862   {
863     tf_shared_lock l(mu_);
864     fbody = (*item)->func_graph;
865     lib_def = (*item)->overlay_lib;
866     executor_type = (*item)->executor_type;
867   }
868   if (!lib_def) {
869     lib_def = base_lib_def_;
870   }
871   std::unique_ptr<Graph> g(new Graph(lib_def));
872   CopyGraph(*fbody->graph, g.get());
873 
874   PruneFunctionBody(fbody->fdef, g.get());
875   optimizer_.Optimize(this, env(), device(), &g, /*shape_map=*/nullptr);
876   TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()),
877                                        device()->name(), g.get()));
878 
879   // Creates an executor based on the g. This must be done without
880   // holding mu_ because create_kernel_ calls back into the library.
881   LocalExecutorParams params;
882   params.device = device_;
883   params.function_library =
884       (*item)->overlay_flr
885           ? static_cast<FunctionLibraryRuntime*>((*item)->overlay_flr)
886           : static_cast<FunctionLibraryRuntime*>(this);
887   if (lib_def == base_lib_def_) {
888     params.create_kernel = create_kernel_;
889   } else {
890     params.create_kernel = [this, lib_def](const NodeDef& ndef,
891                                            OpKernel** kernel) {
892       return CreateKernel(ndef, lib_def, kernel);
893     };
894   }
895   params.delete_kernel = [](OpKernel* kernel) {
896     DeleteNonCachedKernel(kernel);
897   };
898   Graph* graph = g.get();
899   std::unique_ptr<Executor> exec;
900   TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, std::move(g), &exec));
901   {
902     // Guard item since it is already inserted in items_.
903     mutex_lock l(mu_);
904     if ((*item)->exec == nullptr) {
905       (*item)->graph = graph;
906       (*item)->exec = exec.release();
907     }
908   }
909   return Status::OK();
910 }
911 
GetOrCreateItem(LocalHandle local_handle,Item ** item)912 Status FunctionLibraryRuntimeImpl::GetOrCreateItem(LocalHandle local_handle,
913                                                    Item** item) {
914   {
915     tf_shared_lock l(mu_);
916     auto iter = items_.find(local_handle);
917     if (iter == items_.end()) {
918       return errors::Internal("Local function handle ", local_handle,
919                               " is not valid. Likely an internal error.");
920     }
921     *item = iter->second.get();
922     if ((*item)->exec != nullptr) {
923       return Status::OK();
924     }
925   }
926   // NOTE: We need to call CreateItem out of mu_ because creating an
927   // executor needs to call CreateKernel.
928   return CreateItem(item);
929 }
930 
ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options & run_opts,CallFrameInterface * frame,Executor::Args * exec_args)931 void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions(
932     const FunctionLibraryRuntime::Options& run_opts, CallFrameInterface* frame,
933     Executor::Args* exec_args) {
934   // Inherit the step_id from the caller.
935   exec_args->step_id = run_opts.step_id;
936   exec_args->rendezvous = run_opts.rendezvous;
937   exec_args->stats_collector = run_opts.stats_collector;
938   exec_args->cancellation_manager = run_opts.cancellation_manager;
939   exec_args->step_container = run_opts.step_container;
940   if (run_opts.runner) {
941     exec_args->runner = *run_opts.runner;
942   } else {
943     exec_args->runner = default_runner_;
944   }
945   exec_args->collective_executor = run_opts.collective_executor;
946   exec_args->call_frame = frame;
947 }
948 
RunRemote(const Options & opts,Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,Item * item,DoneCallback done)949 void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
950                                            gtl::ArraySlice<Tensor> args,
951                                            std::vector<Tensor>* rets,
952                                            Item* item, DoneCallback done) {
953   string target_device = parent_->GetDeviceName(handle);
954   string source_device = opts.source_device;
955   Rendezvous* rendezvous = opts.rendezvous;
956   DeviceContext* device_context;
957   Status s = parent_->GetDeviceContext(target_device, &device_context);
958   if (!s.ok()) {
959     done(s);
960     return;
961   }
962   int64 src_incarnation, target_incarnation;
963   s = parent_->GetDeviceIncarnation(source_device, &src_incarnation);
964   s.Update(parent_->GetDeviceIncarnation(target_device, &target_incarnation));
965   if (!s.ok()) {
966     done(s);
967     return;
968   }
969 
970   const FunctionBody* fbody = GetFunctionBody(handle);
971   FunctionCallFrame* frame =
972       new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
973   Executor::Args* exec_args = new Executor::Args;
974   ExecutorArgsFromOptions(opts, frame, exec_args);
975 
976   std::vector<AllocatorAttributes> args_alloc_attrs, rets_alloc_attrs;
977   args_alloc_attrs.reserve(fbody->arg_types.size());
978   rets_alloc_attrs.reserve(fbody->ret_types.size());
979   // Note: Functions assume that int32's are always on host memory.
980   for (const auto& arg_type : fbody->arg_types) {
981     AllocatorAttributes arg_alloc_attrs;
982     if (MTypeFromDType(arg_type) == HOST_MEMORY) {
983       arg_alloc_attrs.set_on_host(true);
984     }
985     args_alloc_attrs.push_back(arg_alloc_attrs);
986   }
987   for (const auto& ret_type : fbody->ret_types) {
988     AllocatorAttributes ret_alloc_attrs;
989     if (MTypeFromDType(ret_type) == HOST_MEMORY) {
990       ret_alloc_attrs.set_on_host(true);
991     }
992     rets_alloc_attrs.push_back(ret_alloc_attrs);
993   }
994 
995   bool allow_dead_tensors = opts.allow_dead_tensors;
996 
997   // The ProcFLR sends the arguments to the function from the source_device to
998   // the target_device. So here we receive those arguments. Similarly, when the
999   // computation is done and stored in *rets, we send the return values back
1000   // to the source_device (caller) so that the ProcFLR can receive them later.
1001   std::vector<Tensor>* remote_args = new std::vector<Tensor>;
1002   ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
1003       source_device, target_device, "arg_", src_incarnation, args.size(),
1004       device_context, args_alloc_attrs, rendezvous, remote_args,
1005       [frame, remote_args, item, source_device, target_device,
1006        target_incarnation, rendezvous, device_context, rets, done, exec_args,
1007        rets_alloc_attrs, allow_dead_tensors](const Status& status) {
1008         Status s = status;
1009         if (s.ok()) {
1010           s = frame->SetArgs(*remote_args);
1011         }
1012         if (!s.ok()) {
1013           delete frame;
1014           delete remote_args;
1015           delete exec_args;
1016           done(s);
1017           return;
1018         }
1019         item->exec->RunAsync(
1020             *exec_args,
1021             [frame, rets, done, source_device, target_device,
1022              target_incarnation, rendezvous, device_context, remote_args,
1023              rets_alloc_attrs, allow_dead_tensors](const Status& status) {
1024               Status s = status;
1025               if (s.ok()) {
1026                 s = frame->ConsumeRetvals(rets, allow_dead_tensors);
1027               }
1028               delete frame;
1029               if (!s.ok()) {
1030                 delete remote_args;
1031                 done(s);
1032                 return;
1033               }
1034               s = ProcessFunctionLibraryRuntime::SendTensors(
1035                   target_device, source_device, "ret_", target_incarnation,
1036                   *rets, device_context, rets_alloc_attrs, rendezvous);
1037               delete remote_args;
1038               done(s);
1039             });
1040         delete exec_args;
1041       });
1042 }
1043 
Run(const Options & opts,Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,DoneCallback done)1044 void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
1045                                      gtl::ArraySlice<Tensor> args,
1046                                      std::vector<Tensor>* rets,
1047                                      DoneCallback done) {
1048   if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
1049     done(errors::Cancelled(""));
1050     return;
1051   }
1052   Options run_opts = opts;
1053   if (opts.create_rendezvous) {
1054     Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
1055     run_opts.rendezvous = rendezvous;
1056     run_opts.create_rendezvous = false;
1057     done = [done, rendezvous](const Status& status) {
1058       rendezvous->Unref();
1059       done(status);
1060     };
1061   }
1062 
1063   LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
1064   if (local_handle == kInvalidLocalHandle) {
1065     parent_->Run(run_opts, handle, args, rets, done);
1066     return;
1067   }
1068 
1069   if (run_opts.runner == nullptr) {
1070     run_opts.runner = &default_runner_;
1071   }
1072   DCHECK(run_opts.runner != nullptr);
1073 
1074   Item* item = nullptr;
1075   Status s = GetOrCreateItem(local_handle, &item);
1076   if (!s.ok()) {
1077     done(s);
1078     return;
1079   }
1080 
1081   if (run_opts.remote_execution) {
1082     // NOTE(mrry): `RunRemote()` will set `exec_args->call_frame` for us.
1083     RunRemote(run_opts, handle, args, rets, item, done);
1084     return;
1085   }
1086 
1087   const FunctionBody* fbody = GetFunctionBody(handle);
1088   FunctionCallFrame* frame =
1089       new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
1090   s = frame->SetArgs(args);
1091   if (!s.ok()) {
1092     delete frame;
1093     done(s);
1094     return;
1095   }
1096 
1097   Executor::Args exec_args;
1098   ExecutorArgsFromOptions(run_opts, frame, &exec_args);
1099 
1100   bool allow_dead_tensors = run_opts.allow_dead_tensors;
1101   item->exec->RunAsync(
1102       // Executor args
1103       exec_args,
1104       // Done callback.
1105       [frame, rets, done, allow_dead_tensors](const Status& status) {
1106         Status s = status;
1107         if (s.ok()) {
1108           s = frame->ConsumeRetvals(rets, allow_dead_tensors);
1109         }
1110         delete frame;
1111         done(s);
1112       });
1113 }
1114 
Run(const Options & opts,Handle handle,CallFrameInterface * frame,DoneCallback done)1115 void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
1116                                      CallFrameInterface* frame,
1117                                      DoneCallback done) {
1118   if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
1119     done(errors::Cancelled(""));
1120     return;
1121   }
1122   LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
1123   if (local_handle == kInvalidLocalHandle || opts.remote_execution) {
1124     done(errors::Unimplemented("Remote calling with CallFrameInterface"));
1125     return;
1126   }
1127 
1128   Options run_opts = opts;
1129   if (opts.create_rendezvous) {
1130     Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
1131     run_opts.rendezvous = rendezvous;
1132     run_opts.create_rendezvous = false;
1133     done = std::bind(
1134         [rendezvous](DoneCallback done,
1135                      // Begin unbound arguments.
1136                      const Status& status) {
1137           rendezvous->Unref();
1138           done(status);
1139         },
1140         std::move(done), std::placeholders::_1);
1141   }
1142 
1143   Item* item = nullptr;
1144   Status s = GetOrCreateItem(local_handle, &item);
1145   if (!s.ok()) {
1146     done(s);
1147     return;
1148   }
1149   if (run_opts.runner == nullptr) {
1150     run_opts.runner = &default_runner_;
1151   }
1152   DCHECK(run_opts.runner != nullptr);
1153 
1154   Executor::Args exec_args;
1155   ExecutorArgsFromOptions(run_opts, frame, &exec_args);
1156   item->exec->RunAsync(exec_args, std::move(done));
1157 }
1158 
IsStateful(const string & func)1159 bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) {
1160   const OpDef* op_def;
1161   const Status s = base_lib_def_->LookUpOpDef(func, &op_def);
1162   return s.ok() && op_def->is_stateful();
1163 }
1164 
DebugString(Handle handle)1165 string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
1166   Item* item = nullptr;
1167   LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
1168   Status s = GetOrCreateItem(local_handle, &item);
1169   if (s.ok()) {
1170     return tensorflow::DebugString(item->graph);
1171   } else {
1172     return s.ToString();
1173   }
1174 }
1175 
Clone(std::unique_ptr<FunctionLibraryDefinition> * out_lib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * out_pflr,FunctionLibraryRuntime ** out_flr)1176 Status FunctionLibraryRuntimeImpl::Clone(
1177     std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
1178     std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
1179     FunctionLibraryRuntime** out_flr) {
1180   TF_RETURN_IF_ERROR(
1181       parent_->Clone(env_, graph_def_version_, optimizer_.options(),
1182                      custom_kernel_creator_, out_lib_def, out_pflr));
1183   *out_flr = (*out_pflr)->GetFLR(device_->name());
1184   if (out_flr != nullptr) {
1185     return Status::OK();
1186   } else {
1187     return errors::Internal("Cloning FunctionLibraryRuntime failed.");
1188   }
1189 }
1190 
1191 namespace {
1192 
1193 struct CustomCreatorSingleton {
1194   mutex mu;
1195   CustomKernelCreator custom_creator = nullptr;
1196 
Settensorflow::__anon9a8797220e11::CustomCreatorSingleton1197   void Set(CustomKernelCreator cb) {
1198     mutex_lock l(mu);
1199     custom_creator = std::move(cb);
1200   }
1201 
Gettensorflow::__anon9a8797220e11::CustomCreatorSingleton1202   CustomKernelCreator Get() {
1203     mutex_lock l(mu);
1204     return custom_creator;
1205   }
1206 };
1207 
GetCustomCreatorSingleton()1208 CustomCreatorSingleton* GetCustomCreatorSingleton() {
1209   static CustomCreatorSingleton* ccs = new CustomCreatorSingleton;
1210   return ccs;
1211 }
1212 
1213 }  // namespace
1214 
RegisterDefaultCustomKernelCreator(CustomKernelCreator cb)1215 void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb) {
1216   GetCustomCreatorSingleton()->Set(std::move(cb));
1217 }
1218 
NewFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,Device * device,int graph_def_version,const FunctionLibraryDefinition * lib_def,thread::ThreadPool * thread_pool,const OptimizerOptions & optimizer_options,CustomKernelCreator custom_kernel_creator,ProcessFunctionLibraryRuntime * parent)1219 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
1220     const DeviceMgr* device_mgr, Env* env, Device* device,
1221     int graph_def_version, const FunctionLibraryDefinition* lib_def,
1222     thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
1223     CustomKernelCreator custom_kernel_creator,
1224     ProcessFunctionLibraryRuntime* parent) {
1225   return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
1226       device_mgr, env, device, graph_def_version, lib_def, thread_pool,
1227       optimizer_options, std::move(custom_kernel_creator), parent));
1228 }
1229 
NewFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,Device * device,int graph_def_version,const FunctionLibraryDefinition * lib_def,thread::ThreadPool * thread_pool,const OptimizerOptions & optimizer_options,ProcessFunctionLibraryRuntime * parent)1230 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
1231     const DeviceMgr* device_mgr, Env* env, Device* device,
1232     int graph_def_version, const FunctionLibraryDefinition* lib_def,
1233     thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
1234     ProcessFunctionLibraryRuntime* parent) {
1235   return NewFunctionLibraryRuntime(device_mgr, env, device, graph_def_version,
1236                                    lib_def, thread_pool, optimizer_options,
1237                                    GetCustomCreatorSingleton()->Get(), parent);
1238 }
1239 
RemoveDeadNodes(Graph * g)1240 bool RemoveDeadNodes(Graph* g) {
1241   VLOG(2) << "Removing dead nodes";
1242   std::unordered_set<const Node*> nodes;
1243   for (auto n : g->nodes()) {
1244     if (n->IsSource() || n->IsSink() || n->IsControlFlow() ||
1245         n->op_def().is_stateful()) {
1246       nodes.insert(n);
1247     }
1248   }
1249   return PruneForReverseReachability(g, std::move(nodes));
1250 }
1251 
1252 namespace {
1253 // If 'edges' contains only 1 non-control edge, returns it. Otherwise,
1254 // returns a nullptr.
GetTheOnlyDataEdge(const EdgeSet & edges)1255 const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) {
1256   const Edge* ret = nullptr;
1257   for (const Edge* e : edges) {
1258     if (e->IsControlEdge() || ret) {
1259       // Don't touch it if there is a control edge.
1260       return nullptr;
1261     }
1262     if (IsRefType(e->src()->output_type(e->src_output()))) {
1263       // Don't touch it if the identity node is effectively de-reffing
1264       // a ref.
1265       return nullptr;
1266     }
1267     if (IsRecv(e->src()) || IsSwitch(e->src())) {
1268       // Don't touch it if the identity is introduced for control flow.
1269       // Recv disables all its successors if it receives a dead signal.
1270       // When Recv has an outgoing control edge, the current executor
1271       // would not disable the destination. The current solution (see
1272       // graph_partition.cc) is to add an identity after Recv and change
1273       // the control edge to be from this identity node. So the identity
1274       // can't be removed.
1275       return nullptr;
1276     }
1277     ret = e;
1278   }
1279   return ret;
1280 }
1281 }  // end namespace
1282 
RemoveIdentityNodes(Graph * g)1283 bool RemoveIdentityNodes(Graph* g) {
1284   VLOG(2) << "Removing identity nodes";
1285   bool removed_any = false;
1286   gtl::InlinedVector<Node*, 8> matches;
1287   for (Node* n : g->nodes()) {
1288     if (!n->IsIdentity()) continue;
1289     if (!GetTheOnlyDataEdge(n->in_edges())) continue;
1290 
1291     // Some identity nodes are used as sink nodes to give names to output
1292     // tensors. These nodes are not going to be executed unless they are in the
1293     // fetch set. But if they are in the fetch set we don't want to remove them.
1294     if (n->out_edges().empty()) continue;
1295 
1296     matches.push_back(n);
1297   }
1298   if (!matches.empty()) {
1299     for (Node* n : matches) {
1300       const Edge* in = GetTheOnlyDataEdge(n->in_edges());
1301       for (const Edge* out : n->out_edges()) {
1302         if (out->IsControlEdge()) {
1303           g->AddControlEdge(in->src(), out->dst());
1304         } else {
1305           g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input());
1306         }
1307       }
1308       VLOG(2) << "Remove Identity: " << n->DebugString();
1309       g->RemoveNode(n);
1310       removed_any = true;
1311     }
1312   }
1313   return removed_any;
1314 }
1315 
RemoveListArrayConverter(Graph * g)1316 bool RemoveListArrayConverter(Graph* g) {
1317   VLOG(2) << "Removing list array converter";
1318   gtl::InlinedVector<Node*, 8> matches;
1319   for (Node* n : g->nodes()) {
1320     if ((n->type_string() == "_ListToArray") ||
1321         (n->type_string() == "_ArrayToList")) {
1322       matches.push_back(n);
1323     }
1324   }
1325   bool removed_any = false;
1326   if (!matches.empty()) {
1327     for (Node* n : matches) {
1328       if (n->num_inputs() != n->num_outputs()) {
1329         continue;  // Not expected. Skip.
1330       }
1331       gtl::InlinedVector<Node*, 8> identity_nodes(n->num_inputs(), nullptr);
1332 
1333       const auto no_op = [&](StringPiece name) {
1334         return AddNoOp(absl::StrCat(n->name(), "/", name), g);
1335       };
1336 
1337       const auto identity = [&](StringPiece name, Endpoint input) {
1338         return AddIdentity(absl::StrCat(n->name(), "/", name), g, input);
1339       };
1340 
1341       // Process input edges first.
1342       Node* input_control_node = nullptr;
1343       for (const Edge* e : n->in_edges()) {
1344         if (e->IsControlEdge()) {
1345           if (input_control_node == nullptr) {
1346             // If node "n" has any control dependencies, adds a no-op
1347             // node (input_control_node) which the additional Identity
1348             // nodes depends on and the input_control_node depends on
1349             // the node "n"s control dependencies.
1350             input_control_node = no_op("input_control_node");
1351           }
1352           g->AddControlEdge(e->src(), input_control_node);
1353         } else {
1354           const int index = e->dst_input();
1355           Node** id_node = &identity_nodes[index];
1356           if (*id_node != nullptr) {
1357             LOG(ERROR)
1358                 << "RemoveListArrayConverter unexpected duplicated input: "
1359                 << e->dst_input();
1360             return removed_any;
1361           }
1362           *id_node = identity("input", {e->src(), e->src_output()});
1363         }
1364       }
1365 
1366       // If node "n" has any control dependencies, the added identity
1367       // nodes should have control dependencies on input_control_node.
1368       if (input_control_node != nullptr) {
1369         for (Node* id : identity_nodes) {
1370           g->AddControlEdge(input_control_node, id);
1371         }
1372       }
1373 
1374       Node* output_control_node = nullptr;
1375       for (const Edge* e : n->out_edges()) {
1376         if (e->IsControlEdge()) {
1377           if (output_control_node == nullptr) {
1378             // If node "n" is control-depended upon by other nodes,
1379             // adds a no-op node (output_control_node) which those
1380             // nodes will depend on and output_control_node depends on
1381             // all Identity nodes.
1382             output_control_node = no_op("output_control_node");
1383           }
1384           g->AddControlEdge(output_control_node, e->dst());
1385         } else {
1386           Node* id_node = identity_nodes[e->src_output()];
1387           if (id_node == nullptr) {
1388             LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: "
1389                        << e->src_output();
1390             return removed_any;
1391           }
1392           CHECK(id_node);
1393           g->AddEdge(id_node, 0, e->dst(), e->dst_input());
1394         }
1395       }
1396 
1397       // If any nodes have control dependencies on node "n", those
1398       // nodes should have control dependencies on
1399       // output_control_node.
1400       if (output_control_node != nullptr) {
1401         for (Node* id : identity_nodes) {
1402           g->AddControlEdge(id, output_control_node);
1403         }
1404       }
1405 
1406       g->RemoveNode(n);
1407       removed_any = true;
1408     }
1409   }
1410   return removed_any;
1411 }
1412 
InstantiateFunctionCall(const NodeDef & call_def,FunctionLibraryRuntime & flr,FunctionLibraryRuntime::Handle * handle)1413 Status InstantiateFunctionCall(const NodeDef& call_def,
1414                                FunctionLibraryRuntime& flr,
1415                                FunctionLibraryRuntime::Handle* handle) {
1416   const string* func_name;
1417   AttrSlice attrs;
1418 
1419   NameAttrList func;
1420   if (call_def.op() == "PartitionedCall" ||
1421       call_def.op() == "StatefulPartitionedCall") {
1422     TF_RETURN_IF_ERROR(GetNodeAttr(call_def, "f", &func));
1423     func_name = &func.name();
1424     attrs = AttrSlice(&func.attr());
1425   } else {
1426     func_name = &call_def.op();
1427     attrs = AttrSlice(call_def);
1428   }
1429 
1430   return flr.Instantiate(*func_name, attrs, handle);
1431 }
1432 
1433 namespace {
1434 
ValidateNoInline(const FunctionBody * fbody)1435 Status ValidateNoInline(const FunctionBody* fbody) {
1436   const auto attr = AttrSlice(&fbody->fdef.attr());
1437   bool noinline = false;
1438   if (GetNodeAttr(attr, kNoInlineAttr, &noinline).ok() && noinline) {
1439     return errors::InvalidArgument(
1440         "Can't inline function marked with '_noinline'");
1441   }
1442   return Status::OK();
1443 }
1444 
1445 using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
1446 
1447 }  // namespace
1448 
DebugString() const1449 string InlineFunctionBodyOptions::DebugString() const {
1450   return absl::StrCat("ignore_noinline=", ignore_noinline ? "true" : "false",
1451                       ", override_device=", override_device ? "true" : "false",
1452                       ", output_control_src=",
1453                       output_control_src == OutputControlSrc::kDataOutputs
1454                           ? "DataOutputs"
1455                           : "ControlOutputs");
1456 }
1457 
ValidateInlining(const Node * node,const FunctionBody * fbody,const InlineFunctionBodyOptions & options)1458 Status ValidateInlining(const Node* node, const FunctionBody* fbody,
1459                         const InlineFunctionBodyOptions& options) {
1460   // TODO(ezhulenev): Currently common_runtime function inlining can't guarantee
1461   // that all side-effectful ops will be executed after inlining. See Grappler
1462   // function_optimizer for details. Unify all function inlining mechanism.
1463   // Do not inline if `!fbody->control_ret_nodes.empty()`.
1464 
1465   const auto num_node_inputs = static_cast<size_t>(node->num_inputs());
1466   const auto num_node_outputs = static_cast<size_t>(node->num_outputs());
1467 
1468   if (num_node_inputs != fbody->arg_types.size() ||
1469       num_node_inputs != fbody->arg_nodes.size()) {
1470     return errors::InvalidArgument(
1471         "Node inputs do not match function arguments: inputs=", num_node_inputs,
1472         " arg_types=", fbody->arg_types.size(),
1473         " arg_nodes=", fbody->arg_nodes.size());
1474   }
1475 
1476   if (num_node_outputs != fbody->ret_types.size() ||
1477       num_node_outputs != fbody->ret_nodes.size()) {
1478     return errors::InvalidArgument(
1479         "Node outputs do not match function returns: outputs=",
1480         num_node_outputs, " ret_types=", fbody->ret_types.size(),
1481         " ret_nodes=", fbody->ret_nodes.size());
1482   }
1483 
1484   for (int i = 0; i < node->num_inputs(); ++i) {
1485     if (node->input_type(i) != fbody->arg_types[i]) {
1486       return errors::InvalidArgument(
1487           "Node input type doesn't match function argument type: ",
1488           node->input_type(i), " != ", fbody->arg_types[i], " @ index=", i);
1489     }
1490   }
1491   for (int i = 0; i < node->num_outputs(); ++i) {
1492     if (node->output_type(i) != fbody->ret_types[i]) {
1493       return errors::InvalidArgument(
1494           "Node output type doesn't match function return type: ",
1495           node->output_type(i), " != ", fbody->ret_types[i], " @ index=", i);
1496     }
1497   }
1498 
1499   if (!options.ignore_noinline) {
1500     TF_RETURN_IF_ERROR(ValidateNoInline(fbody));
1501   }
1502 
1503   return Status::OK();
1504 }
1505 
// Function inlining must preserve function execution semantics with regard to
// side-effect visibility. Tensorflow in Eager mode has an automatic control
// dependencies tracking mechanism, which enforces well-defined execution order
// of all side-effects. Any other frontend (e.g. Swift) must produce graphs
// following the same rules, to ensure that function inlining works correctly.
1511 //
1512 // IMPORTANT: Currently we do not have a true notion of "side-effectful" node,
1513 // we assume that all stateful nodes might have side-effects, though it's not
1514 // true in practice, e.g. `ReadVariableOp` doesn't have an observable
1515 // side-effect.
1516 //
1517 // Automatic control dependency rules in Tensorflow 2.0 (python in eager mode):
1518 //
// 1) When a function has a resource (DT_RESOURCE data type) input argument it
//    "captures" the mutable resource. This is implemented by automatically
//    adding an incoming control edge from the previous side-effectful op
//    touching that resource, and an outgoing control edge to the next
//    side-effectful op using the same resource. This serializes the mutations
//    of the resource to make graph execution deterministic.
1525 //
// 2) All stateful ops inside a function body are guaranteed to execute in
//    program order; this is achieved by adding control edges between stateful
//    ops at graph construction time. Stateful ops (or ops that must execute)
//    should be in the function control return set. Having a data edge to the
//    regular function output might not be enough, because after function
//    inlining that data output might be unused.
1532 //
1533 // 3) Furthermore, all ops accepting the same resource as an input are
1534 //    guaranteed to run in program order. This is also done by adding control
1535 //    edges at graph construction time. The last op touching the resource
1536 //    must be in a control return set, which will guarantee that all side
1537 //    effects to the resource will happen before function completion.
1538 //
1539 // Function inlining must preserve side-effect visibility:
1540 //
// 1) All side-effects to the captured resources that happened before the
//    function call must be visible to the function body nodes using those
//    resources.
//
// 2) All side-effects to the captured resources that happened inside the
//    function body must be visible to every op/function using those resources
//    after the function call completed.
1547 //
1548 // To guarantee that these properties are preserved after inlining we:
1549 //
1550 // 1) Create "input_control_node" NoOp. Function call node incoming control
1551 //    edges will be forwarded *to* this node. Function inputs (Identity nodes)
1552 //    will have a control edge *from* this node. If function body has nodes
1553 //    without inputs, they will have a control edge *from* this node.
1554 //
1555 // 2) Create "output_control_node" NoOp. All nodes that have incoming control
1556 //    edge *from* the function call node, will be forwarded to this node.
1557 //
1558 //    We have two options for choosing which nodes will have a control edge *to*
1559 //    the "output control node":
1560 //       a) control returns            (`control_ret` field in FunctionDef)
1561 //       b) data returns               (`ret` field in FunctionDef)
1562 //
1563 //    We do a) for multi-device function calls in Tensorflow v2 and b)
1564 //    for the rest for compatibility with Tensorflow v1.
1565 //
1566 //    Following the automatic control dependencies tracking rules, a node that
1567 //    has an incoming control edge from the function call node is dependent on
1568 //    the side-effects happening inside the function body. The output control
1569 //    node will guarantee side-effects execution order.
1570 //
1571 //    If function call node doesn't have an outgoing control edge, it means that
1572 //    no one is interested in observing side-effects that might have happened.
1573 //
1574 // Function inlining might leave the graph in partially-placed state. Function
1575 // inlining caller must call Placer to guarantee that all nodes are placed.
1576 //
1577 // Function inlining with `options.override_device=true` will leave graph in
1578 // fully placed state, by overriding all inlined nodes devices with the caller
1579 // node device, but it will make functions always single-device. These functions
1580 // after inlining will not be able to handle resources on multiple devices. This
1581 // is currently acceptable for XLA use cases (XLA cluster is always executed on
1582 // a single device).
1583 //
1584 // TODO(ezhulenev): Documentation above is ahead of implementation below.
InlineFunctionBody(const FunctionLibraryDefinition & flib_def,Graph * g,Node * caller,const FunctionBody * fbody,const InlineFunctionBodyOptions & options)1585 Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
1586                           Node* caller, const FunctionBody* fbody,
1587                           const InlineFunctionBodyOptions& options) {
1588   VLOG(3) << "Inline function call: " << SummarizeNode(*caller) << " ["
1589           << options.DebugString() << "]";
1590   VLOG(4) << "Inlined function definition: " << DebugString(fbody->fdef);
1591 
1592   Status validation = ValidateInlining(caller, fbody, options);
1593   if (!validation.ok()) {
1594     LOG(WARNING) << "Inlining mismatch: " << SummarizeNode(*caller) << " vs. "
1595                  << DebugString(fbody->graph);
1596     return errors::Internal("Inlining mismatch: ", validation.error_message());
1597   }
1598 
1599   // ------------------------------------------------------------------------ //
1600   // Helper functions to create `NoOp` and `Identity` nodes for auxiliary
1601   // control nodes and inlined function inputs and outputs.
1602 
1603   // Add a NoOp node for function control inputs/outputs.
1604   const auto no_op = [&](StringPiece name) {
1605     Node* node = AddNoOp(absl::StrCat(caller->name(), "/", name), g);
1606     node->set_requested_device(caller->def().device());
1607     return node;
1608   };
1609 
1610   // Add an Identity node for function data inputs/outputs.
1611   const auto identity = [&](StringPiece name, Endpoint input) {
1612     return AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
1613   };
1614 
1615   // ------------------------------------------------------------------------ //
1616   // Input edges. For data edges coming into "caller", we first compute the
1617   // <src>:<src_output> for the i-th input in "inputs".
1618   // If "caller" has any input control dependencies, we add a NoOp
1619   // node "input_control_node", which depends on "caller"'s control inputs.
1620   std::vector<Endpoint> inputs(caller->num_inputs());
1621   Node* input_control_node = nullptr;
1622   for (const Edge* e : caller->in_edges()) {
1623     if (e->IsControlEdge()) {
1624       if (input_control_node == nullptr) {
1625         input_control_node = no_op("input_control_node");
1626       }
1627       g->AddControlEdge(e->src(), input_control_node);
1628     } else {
1629       inputs[e->dst_input()] = {e->src(), e->src_output()};
1630     }
1631   }
1632 
1633   // ------------------------------------------------------------------------ //
1634   // Duplicate fbody->graph into 'g'.  First, we copy the nodes of
1635   // fbody->graph into 'g' except the source and sink nodes.  We copy
1636   // edges among nodes in 'fbody->graph'.
1637   //
1638   // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
1639   // remember 'y' in node_map[x->id()].
1640   std::vector<Node*> node_map(fbody->graph->num_node_ids());
1641   for (Node* n : fbody->graph->op_nodes()) {
1642     NodeDef ndef = n->def();
1643     ndef.set_name(strings::StrCat(caller->name(), "/", ndef.name()));
1644     if (options.override_device || ndef.device().empty()) {
1645       ndef.set_device(caller->def().device());
1646     }
1647     for (auto& attr : *ndef.mutable_attr()) {
1648       if (attr.first == "_class") {
1649         attr.second.set_s(
1650             strings::StrCat(caller->name(), "/", attr.second.s()));
1651       }
1652     }
1653     Status added_node;
1654     Node* clone = g->AddNode(ndef, &added_node);
1655     if (options.override_device && !caller->assigned_device_name().empty()) {
1656       clone->set_assigned_device_name(caller->assigned_device_name());
1657     }
1658     TF_CHECK_OK(added_node);
1659     node_map[n->id()] = clone;
1660 
1661     // If there is an input control node, and one of:
1662     // a) the node has no data or control inputs, or
1663     // b) the node is a function call or SymbolicGradient,
1664     // then add a control edge from the input control node to the clone.
1665     //
1666     // We must not execute any nodes if the original function call would not
1667     // have executed. This is especially critical when the function call is
1668     // inside a control-flow construct like tf.cond(). Case (a) ensures that
1669     // such nodes do not run.
1670     //
1671     // The purpose of case (b) is to ensure that instances of case (a) created
1672     // by further inlining steps also receive the control dependency.
1673     //
1674     // TODO(ezhulenev): If caller has no control inputs, should we add a control
1675     // edge from one of the inputs to ensure that function body node will
1676     // execute in correct frame?
1677     if (input_control_node) {
1678       bool has_inputs = absl::c_any_of(
1679           n->in_edges(), [](const Edge* e) { return !e->src()->IsSource(); });
1680       if (!has_inputs || flib_def.Find(clone->type_string()) != nullptr ||
1681           clone->type_string() == kGradientOp) {
1682         g->AddControlEdge(input_control_node, clone);
1683       }
1684     }
1685   }
1686   for (const Edge* e : fbody->graph->edges()) {
1687     if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
1688         e->dst()->IsSink()) {
1689       continue;
1690     }
1691     Node* src_copy = node_map[e->src()->id()];
1692     Node* dst_copy = node_map[e->dst()->id()];
1693     g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
1694   }
1695 
1696   // ------------------------------------------------------------------------ //
1697   // Connect input edges.
1698   //
1699   // We create one Identity node for each input. Then, we connect inputs[i] to
1700   // the i-th identity node added. The nodes that previously connected
1701   // to the j-th output of i-th arg node are reconnected to the i-th
1702   // identity node.
1703   //
1704   // The added identity nodes depend on "input_control_node".
1705   for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
1706     Node* arg = node_map[fbody->arg_nodes[i]->id()];
1707     Node* n = identity("input", inputs[i]);
1708     if (input_control_node) {
1709       g->AddControlEdge(input_control_node, n);
1710     }
1711     for (const Edge* e : arg->out_edges()) {
1712       if (e->IsControlEdge()) {
1713         g->AddControlEdge(n, e->dst());
1714       } else {
1715         g->AddEdge(n, 0, e->dst(), e->dst_input());
1716       }
1717     }
1718     node_map[fbody->arg_nodes[i]->id()] = n;
1719     g->RemoveNode(arg);  // 'arg' is disconnected.
1720   }
1721 
1722   // ------------------------------------------------------------------------ //
1723   // Connect output edges.
1724   //
1725   // For i-th return node in fbody->graph, we add in "g" an identity node
1726   // (outputs[i-th]). We then reconnect every incoming edge into the i-th return
1727   // node to the added identity node.
1728   //
1729   // For every data edge coming out of "callee"s i-th output, we reconnect it to
1730   // the i-th identity added above.
1731   //
1732   // If "callee" is control-depended upon by any other nodes, we add a NoOp node
1733   // "output_control_node". "output_control_node" depends on all identity nodes
1734   // added above or on all control return nodes (controlled by
1735   // `options.output_control_src` value). And nodes previously depend on
1736   // "callee" is changed to depend on "output_control_node".
1737   std::vector<Node*> outputs(caller->num_outputs());
1738   for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) {
1739     Node* ret = node_map[fbody->ret_nodes[i]->id()];
1740     Endpoint data;  // Data input for the ret node.
1741     for (const Edge* e : ret->in_edges()) {
1742       if (!e->IsControlEdge()) {
1743         data = {e->src(), e->src_output()};
1744         break;
1745       }
1746     }
1747     CHECK(data.node != nullptr);
1748     Node* n = identity("output", data);
1749     outputs[i] = n;
1750     for (const Edge* e : ret->in_edges()) {
1751       if (e->IsControlEdge()) {
1752         g->AddControlEdge(e->src(), n);
1753       }
1754     }
1755     g->RemoveNode(ret);  // 'ret' is disconnected.
1756   }
1757   Node* output_control_node = nullptr;
1758   for (const Edge* e : caller->out_edges()) {
1759     if (e->IsControlEdge()) {
1760       if (output_control_node == nullptr) {
1761         output_control_node = no_op("output_control_node");
1762         if (options.output_control_src ==
1763             InlineFunctionBodyOptions::OutputControlSource::kDataOutputs) {
1764           for (Node* n : outputs) {
1765             g->AddControlEdge(n, output_control_node);
1766           }
1767         } else {
1768           for (Node* fbody_node : fbody->control_ret_nodes) {
1769             Node* n = node_map[fbody_node->id()];
1770             g->AddControlEdge(n, output_control_node);
1771           }
1772         }
1773       }
1774       g->AddControlEdge(output_control_node, e->dst());
1775     } else {
1776       g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input());
1777     }
1778   }
1779   g->RemoveNode(caller);  // 'caller' is replaced with inlined nodes.
1780 
1781   return Status::OK();
1782 }
1783 
IsFunctionCall(const FunctionLibraryDefinition & lib_def,const Node & node)1784 bool IsFunctionCall(const FunctionLibraryDefinition& lib_def,
1785                     const Node& node) {
1786   return node.IsPartitionedCall() ||
1787          node.type_string() == FunctionLibraryDefinition::kGradientOp ||
1788          lib_def.Find(node.def().op()) != nullptr;
1789 }
1790 
ExpandInlineFunctions(FunctionLibraryRuntime * lib,Graph * graph,const ExpandInlineFunctionsOptions & options)1791 bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
1792                            const ExpandInlineFunctionsOptions& options) {
1793   std::vector<std::pair<Node*, const FunctionBody*>> candidates;
1794 
1795   const FunctionLibraryDefinition* fld = lib->GetFunctionLibraryDefinition();
1796 
1797   for (Node* node : graph->nodes()) {
1798     // Skip nodes that are not function calls or SymbolicGradient calls.
1799     if (!IsFunctionCall(*lib->GetFunctionLibraryDefinition(), *node)) {
1800       continue;
1801     }
1802     // Skip function calls that marked noinline.
1803     bool noinline;
1804     if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) {
1805       VLOG(3) << "noinline: " << SummarizeNode(*node);
1806       continue;
1807     }
1808     FunctionLibraryRuntime::Handle handle;
1809     Status s = InstantiateFunctionCall(node->def(), *lib, &handle);
1810     if (!s.ok()) {
1811       LOG(ERROR) << "Failed to instantiate a function:  " << s.error_message();
1812       continue;
1813     }
1814     const FunctionBody* fbody = lib->GetFunctionBody(handle);
1815     CHECK_NOTNULL(fbody);
1816     candidates.emplace_back(node, fbody);
1817   }
1818 
1819   bool inlined_any = false;
1820   for (const auto& p : candidates) {
1821     Status inlined = InlineFunctionBody(*fld, graph, p.first, p.second,
1822                                         p.first->IsPartitionedCall()
1823                                             ? options.multi_device_options
1824                                             : options.native_options);
1825     if (inlined.ok()) {
1826       inlined_any = true;
1827     } else {
1828       VLOG(1) << "Failed to inline function call: node=" << p.first->name()
1829               << " error=" << inlined.error_message();
1830     }
1831   }
1832 
1833   // TODO(ezhulenev): Release handles for inlined function calls.
1834 
1835   return inlined_any;
1836 }
1837 
NewName(const Node * n,bool pretty)1838 string NewName(const Node* n, bool pretty) {
1839   if (pretty) {
1840     return strings::StrCat(n->type_string(), n->id());
1841   } else {
1842     return strings::StrCat("n", n->id());
1843   }
1844 }
1845 
// TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef,
// and stash the original NodeDef name as an attr for documentation
// purposes.
ToGraphDef(const Graph * g,GraphDef * gdef,bool pretty)1849 void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
1850   // We visit nodes in forward topological sort order, which is a
1851   // possible execution order of the graph.
1852   gtl::InlinedVector<const Edge*, 4> inputs;
1853   gdef->Clear();
1854   gdef->mutable_versions()->CopyFrom(g->versions());
1855 
1856   std::vector<Node*> start_nodes;
1857   for (Node* n : g->nodes()) {
1858     if (n->out_edges().empty()) {
1859       start_nodes.push_back(n);
1860     }
1861   }
1862 
1863   ReverseDFSFrom(*g, start_nodes, nullptr, [gdef, pretty, &inputs](Node* n) {
1864     if (!n->IsOp()) return;
1865     NodeDef* ndef = gdef->add_node();
1866     ndef->set_name(NewName(n, pretty));
1867     ndef->set_op(n->type_string());
1868     for (const auto& attr : n->attrs()) {
1869       (*ndef->mutable_attr())[attr.first] = attr.second;
1870     }
1871 
1872     if (!n->assigned_device_name().empty()) {
1873       ndef->set_device(n->assigned_device_name());
1874     } else {
1875       ndef->set_device(n->requested_device());
1876     }
1877 
1878     inputs.clear();
1879     inputs.resize(n->num_inputs());
1880     for (const Edge* e : n->in_edges()) {
1881       if (e->IsControlEdge()) {
1882         inputs.push_back(e);
1883       } else {
1884         if (inputs[e->dst_input()] == nullptr) {
1885           inputs[e->dst_input()] = e;
1886         } else {
1887           LOG(WARNING) << "Malformed graph node. multiple input edges: "
1888                        << n->DebugString();
1889         }
1890       }
1891     }
1892     // node->name() is merely NodeDef::name, which are not guaranteed
1893     // to be unique and stable after optimization rewrites. Therefore,
1894     // we use "n<node id>" instead.
1895     for (const Edge* e : inputs) {
1896       if (e == nullptr) {
1897         ndef->add_input("unknown");
1898         continue;
1899       }
1900       const string srcname = NewName(e->src(), pretty);
1901       if (!e->src()->IsOp()) {
1902       } else if (e->IsControlEdge()) {
1903         ndef->add_input(strings::StrCat("^", srcname));
1904       } else if (e->src_output() == 0) {
1905         ndef->add_input(srcname);
1906       } else {
1907         ndef->add_input(strings::StrCat(srcname, ":", e->src_output()));
1908       }
1909     }
1910   });
1911 }
1912 
DebugString(const Graph * g)1913 string DebugString(const Graph* g) {
1914   GraphDef gdef;
1915   ToGraphDef(g, &gdef);
1916   return DebugString(gdef);
1917 }
1918 
FunctionBody(const FunctionDef & f,DataTypeSlice arg_t,DataTypeSlice ret_t,Graph * g)1919 FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
1920                            DataTypeSlice ret_t, Graph* g)
1921     : fdef(f),
1922       graph(g),
1923       arg_types(arg_t.begin(), arg_t.end()),
1924       ret_types(ret_t.begin(), ret_t.end()) {
1925   // 1. Find regular Arg/Ret nodes.
1926   this->arg_nodes.resize(arg_types.size());
1927   this->ret_nodes.resize(ret_types.size());
1928   for (Node* n : this->graph->op_nodes()) {
1929     gtl::InlinedVector<Node*, 4>* node_vec;
1930     if (n->type_string() == kRetOp || n->type_string() == kDeviceRetOp) {
1931       node_vec = &this->ret_nodes;
1932     } else if (n->type_string() == kArgOp || n->type_string() == kDeviceArgOp) {
1933       node_vec = &this->arg_nodes;
1934     } else {
1935       continue;
1936     }
1937     int index;
1938     TF_CHECK_OK(GetNodeAttr(n->attrs(), "index", &index));
1939     CHECK_LE(0, index);
1940     CHECK_LT(index, node_vec->size());
1941     (*node_vec)[index] = n;
1942   }
1943   // 2. Find ControlRet nodes that must be always executed.
1944   std::unordered_set<StringPiece, StringPieceHasher> control_ret_node_names;
1945   for (const auto& control_ret : fdef.control_ret()) {
1946     control_ret_node_names.insert(control_ret.second);
1947   }
1948   this->control_ret_nodes.reserve(control_ret_node_names.size());
1949   for (Node* n : this->graph->op_nodes()) {
1950     if (control_ret_node_names.count(n->name()) > 0) {
1951       this->control_ret_nodes.push_back(n);
1952     }
1953   }
1954 }
1955 
~FunctionBody()1956 FunctionBody::~FunctionBody() { delete this->graph; }
1957 
1958 class SymbolicGradientHelper {
1959  public:
SymbolicGradientHelper(const FunctionBody & f)1960   explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {}
1961 
~SymbolicGradientHelper()1962   ~SymbolicGradientHelper() { delete gbody_; }
1963 
1964   FunctionBody* Compute();
1965 
1966  private:
1967   const FunctionBody* fbody_;
1968   FunctionBody* gbody_ = nullptr;
1969 
1970   // Makes a copy of fbody_ in gbody_.
1971   void Copy();
1972 
1973   TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientHelper);
1974 };
1975 
Copy()1976 void SymbolicGradientHelper::Copy() {
1977   const Graph& src = *(fbody_->graph);
1978   gbody_->graph = new Graph(src.op_registry());
1979   Graph* dst = gbody_->graph;
1980 
1981   std::vector<Node*> node_map(src.num_node_ids());
1982 
1983   // Copy the nodes.
1984   node_map[src.source_node()->id()] = dst->source_node();
1985   node_map[src.sink_node()->id()] = dst->sink_node();
1986   for (Node* n : src.op_nodes()) {
1987     node_map[n->id()] = dst->CopyNode(n);
1988   }
1989 
1990   // Copy the edges.
1991   for (const Edge* e : src.edges()) {
1992     Node* src_copy = node_map[e->src()->id()];
1993     Node* dst_copy = node_map[e->dst()->id()];
1994     dst->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
1995   }
1996 
1997   // Save inputs in copied graph.
1998   CHECK_EQ(fbody_->arg_types.size(), fbody_->arg_nodes.size());
1999   gbody_->arg_types = fbody_->arg_types;
2000   for (std::size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
2001     gbody_->arg_nodes.push_back(node_map[fbody_->arg_nodes[i]->id()]);
2002   }
2003 
2004   // Save outputs in copied graph.
2005   CHECK_EQ(fbody_->ret_types.size(), fbody_->ret_nodes.size());
2006   gbody_->ret_types = fbody_->ret_types;
2007   for (std::size_t i = 0; i < fbody_->ret_nodes.size(); ++i) {
2008     gbody_->ret_nodes.push_back(node_map[fbody_->ret_nodes[i]->id()]);
2009   }
2010 }
2011 
Compute()2012 FunctionBody* SymbolicGradientHelper::Compute() {
2013   CHECK(gbody_ == nullptr);
2014   gbody_ = new FunctionBody;
2015 
2016   // Copy fbody_ into gbody_.
2017   Copy();
2018 
2019   Graph* g = gbody_->graph;
2020 
2021   const int num_y = static_cast<int>(gbody_->ret_nodes.size());
2022 
2023   // Populate 'y_node_outputs_' with node function body outputs.
2024   // Populate 'y_grad_nodes' with initial gradient nodes for each return node
2025   // of the original function body (these will be 'arg' nodes in the function
2026   // gradient body).
2027   std::vector<NodeOut> y_node_outputs;
2028   y_node_outputs.reserve(num_y);
2029   std::vector<NodeOut> y_grad_node_outputs;
2030   y_grad_node_outputs.reserve(num_y);
2031   for (int i = 0; i < num_y; ++i) {
2032     Node* y = gbody_->ret_nodes[i];
2033     y_node_outputs.push_back({y, 0});
2034     DCHECK_EQ(y->type_string(), kRetOp);
2035     const DataType dtype = y->input_type(0);
2036     const int index = static_cast<int>(gbody_->arg_nodes.size());
2037     Node* dy = AddArg(g, dtype, index);
2038     gbody_->arg_types.push_back(dtype);
2039     gbody_->arg_nodes.push_back(dy);
2040     y_grad_node_outputs.push_back({dy, 0});
2041   }
2042 
2043   // Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
2044   const size_t num_x = fbody_->arg_nodes.size();
2045   std::vector<NodeOut> x_node_outputs;
2046   x_node_outputs.reserve(num_x);
2047   for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
2048     x_node_outputs.push_back({gbody_->arg_nodes[i], 0});
2049   }
2050 
2051   // Call AddSymbolicGradients which will add nodes to graph 'g' that
2052   // compute the function gradient (adding an entry in 'x_grad_node_outputs'
2053   // for each node in 'x_node_outputs').
2054   std::vector<NodeOut> x_grad_node_outputs;
2055   TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs,
2056                                    y_grad_node_outputs, &x_grad_node_outputs,
2057                                    g));
2058 
2059   // Remove the old return nodes from the function body.
2060   for (Node* n : gbody_->ret_nodes) {
2061     g->RemoveNode(n);
2062   }
2063   gbody_->ret_types = fbody_->arg_types;
2064   // TODO(apassos): use the right dtype for gradients of  resource variables
2065   for (int i = 0; i < gbody_->ret_types.size(); ++i) {
2066     if (gbody_->ret_types[i] == DT_RESOURCE) {
2067       gbody_->ret_types[i] = DT_FLOAT;
2068     }
2069   }
2070   gbody_->ret_nodes.clear();
2071   // Add new return nodes to the function gradient body for each node
2072   // in 'x_grad_nodes'.
2073   const int arg_types_size = static_cast<int>(fbody_->arg_types.size());
2074   for (int i = 0; i < arg_types_size; ++i) {
2075     Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
2076     Node* ret = AddRet(g, grad, i);
2077     gbody_->ret_nodes.push_back(ret);
2078   }
2079 
2080   auto ret = gbody_;
2081   gbody_ = nullptr;
2082   return ret;
2083 }
2084 
// Returns a newly allocated body computing the symbolic gradient of `f`;
// the caller takes ownership.
FunctionBody* SymbolicGradient(const FunctionBody& f) {
  SymbolicGradientHelper helper(f);
  return helper.Compute();
}
2088 
FunctionDefToBodyHelper(const FunctionDef & fdef,const AttrSlice & attrs,const FunctionLibraryDefinition * const lib_def,const std::function<Status (const string &,const OpDef **)> & get_func_sig,FunctionBody ** fbody)2089 Status FunctionDefToBodyHelper(
2090     const FunctionDef& fdef, const AttrSlice& attrs,
2091     const FunctionLibraryDefinition* const lib_def,
2092     const std::function<Status(const string&, const OpDef**)>& get_func_sig,
2093     FunctionBody** fbody) {
2094   // Instantiates the function template into a graph def.
2095   InstantiationResult result;
2096   TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result));
2097 
2098   std::unique_ptr<Graph> graph(new Graph(lib_def));
2099   GraphConstructorOptions opts;
2100   opts.allow_internal_ops = true;
2101   opts.expect_device_spec = false;
2102   TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get()));
2103 
2104   // Call BuildControlFlowInfo to validate that this function body has
2105   // well-formed control flow.
2106   std::vector<ControlFlowInfo> dummy;
2107   TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy));
2108 
2109   *fbody = new FunctionBody(fdef, result.arg_types, result.ret_types,
2110                             graph.release());
2111   return Status::OK();
2112 }
2113 
2114 }  // end namespace tensorflow
2115