1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/function.h"
17 
18 #include <deque>
19 #include <vector>
20 
21 #include "absl/algorithm/container.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/core/common_runtime/device.h"
25 #include "tensorflow/core/common_runtime/executor.h"
26 #include "tensorflow/core/common_runtime/executor_factory.h"
27 #include "tensorflow/core/common_runtime/gradients.h"
28 #include "tensorflow/core/common_runtime/graph_constructor.h"
29 #include "tensorflow/core/common_runtime/graph_optimizer.h"
30 #include "tensorflow/core/common_runtime/inline_function_utils.h"
31 #include "tensorflow/core/common_runtime/memory_types.h"
32 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
33 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
34 #include "tensorflow/core/framework/collective.h"
35 #include "tensorflow/core/framework/function.h"
36 #include "tensorflow/core/framework/node_def.pb.h"
37 #include "tensorflow/core/framework/node_def_util.h"
38 #include "tensorflow/core/framework/op.h"
39 #include "tensorflow/core/framework/op_kernel.h"
40 #include "tensorflow/core/framework/versions.pb.h"
41 #include "tensorflow/core/graph/algorithm.h"
42 #include "tensorflow/core/graph/control_flow.h"
43 #include "tensorflow/core/graph/node_builder.h"
44 #include "tensorflow/core/graph/optimizer_cse.h"
45 #include "tensorflow/core/lib/core/threadpool.h"
46 #include "tensorflow/core/lib/gtl/map_util.h"
47 #include "tensorflow/core/platform/macros.h"
48 #include "tensorflow/core/profiler/lib/connected_traceme.h"
49 #include "tensorflow/core/profiler/lib/traceme.h"
50 #include "tensorflow/core/protobuf/config.pb.h"
51 
52 // See core/kernels/function_ops.cc for related kernels.
53 
54 namespace tensorflow {
55 
// A few string constants used throughout this module.
// Canonical op names for the special function-boundary nodes, re-exported
// from FunctionLibraryDefinition for brevity within this module.
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
static constexpr const char* const kDeviceArgOp =
    FunctionLibraryDefinition::kDeviceArgOp;
static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
static constexpr const char* const kDeviceRetOp =
    FunctionLibraryDefinition::kDeviceRetOp;
static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;
// Name prefix used for nodes added to a graph by this module (see AddArg /
// AddRet below).
static constexpr const char* const kNodeLabel = "Func";
static constexpr const char* const kFuncAttr =
    FunctionLibraryDefinition::kFuncAttr;
68 
69 // Represents the index-th output of a node.
70 struct Endpoint {
71   Node* node;
72   int index;
73 
74   // Returns the string name represents this endpoint.
nametensorflow::Endpoint75   string name() const {
76     if (index == 0) {
77       return node->name();
78     } else {
79       return strings::StrCat(node->name(), ":", index);
80     }
81   }
82 
dtypetensorflow::Endpoint83   DataType dtype() const { return node->output_type(index); }
84 };
85 
86 struct EndpointHash {
operator ()tensorflow::EndpointHash87   uint64 operator()(const Endpoint& x) const {
88     return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
89                   x.index);
90   }
91 };
92 
93 struct EndpointEq {
operator ()tensorflow::EndpointEq94   bool operator()(const Endpoint& x, const Endpoint& y) const {
95     return (x.node == y.node) && (x.index == y.index);
96   }
97 };
98 
99 // The following Add* routines are used to add a few graph nodes while
100 // functions are transformed.
AddArg(Graph * g,DataType dtype,int index)101 static Node* AddArg(Graph* g, DataType dtype, int index) {
102   DCHECK_LT(0, dtype);
103   DCHECK_LT(dtype, DT_FLOAT_REF);
104   NodeDef ndef;
105   ndef.set_name(g->NewName(kNodeLabel));
106   ndef.set_op(kArgOp);
107   AddNodeAttr("T", dtype, &ndef);
108   AddNodeAttr("index", index, &ndef);
109   Status s;
110   Node* ret = g->AddNode(ndef, &s);
111   TF_CHECK_OK(s);
112   return ret;
113 }
114 
AddRet(Graph * g,Endpoint input,int index)115 static Node* AddRet(Graph* g, Endpoint input, int index) {
116   DCHECK_LT(0, input.dtype());
117   DCHECK_LT(input.dtype(), DT_FLOAT_REF);
118   NodeDef ndef;
119   ndef.set_name(g->NewName(kNodeLabel));
120   ndef.set_op(kRetOp);
121   ndef.add_input(input.name());
122   AddNodeAttr("T", input.dtype(), &ndef);
123   AddNodeAttr("index", index, &ndef);
124   Status s;
125   Node* ret = g->AddNode(ndef, &s);
126   TF_CHECK_OK(s);
127   g->AddEdge(input.node, input.index, ret, 0);
128   return ret;
129 }
130 
131 // FunctionLibraryRuntime implementation that forwards all the function calls to
132 // the base runtime implementation, and only overrides FunctionLibraryDefinition
133 // in calls to Instantiate (if caller doesn't provide the
134 // InstantiateOptions::lib_def option).
135 //
136 // When the function library runtime (FunctionLibraryRuntimeImpl specifically)
137 // instantiates a function into a Graph object, it also creates an Executor for
138 // it. That executor has a pointer to the function library runtime instance,
139 // that is used to instantiate all nested function calls.
140 //
141 // The function library definition used to instantiate the function must be
142 // preserved in the Executor's function library runtime.
143 //
144 // IMPORTANT: This runtime is intended for use only in executors created for
145 // functions instantiated into a graph in FunctionLibraryRuntimeImpl.
class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
 public:
  // Both pointers are borrowed and must outlive this overlay.
  FunctionLibraryRuntimeOverlay(FunctionLibraryRuntime* base_flr,
                                const FunctionLibraryDefinition* lib_def)
      : base_flr_(base_flr), lib_def_(lib_def) {}
  ~FunctionLibraryRuntimeOverlay() override;

  // Forwards to `base_flr_`, filling in InstantiateOptions::lib_def with
  // `lib_def_` when the caller did not set one explicitly.
  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle h) override;

  Status GetRetTypes(Handle h, DataTypeVector* ret_types) override;

  // The Run/RunSync overloads below forward unchanged to the base runtime.
  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;

  void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame,
           DoneCallback done) override;

  Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
                 std::vector<Tensor>* rets) override;

  Status RunSync(Options opts, Handle handle,
                 CallFrameInterface* frame) override;

  // Always returns an error: kernel creation through an overlay is disabled
  // (see the definition for the rationale).
  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      OpKernel** kernel) override;

  // Looks up `function_name` only in `lib_def_`; does NOT consult the base
  // runtime's library.
  bool IsStateful(const string& function_name) const override;

  // Returns `lib_def_` when set, otherwise the base runtime's definition.
  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override;

  Env* env() override;
  const ConfigProto* const config_proto() override;
  Device* device() override;
  const Device* device() const override;
  std::function<void(std::function<void()>)>* runner() override;
  const DeviceMgr* device_mgr() const override;

  string DebugString(Handle handle) override;
  int graph_def_version() const override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr,
               bool skip_flib_def = false) override;

 private:
  FunctionLibraryRuntime* base_flr_;          // not owned
  const FunctionLibraryDefinition* lib_def_;  // not owned
};
202 
// Defaulted out of line; the overlay owns nothing, so nothing to release.
FunctionLibraryRuntimeOverlay::~FunctionLibraryRuntimeOverlay() = default;
204 
Instantiate(const string & function_name,AttrSlice attrs,const InstantiateOptions & options,Handle * handle)205 Status FunctionLibraryRuntimeOverlay::Instantiate(
206     const string& function_name, AttrSlice attrs,
207     const InstantiateOptions& options, Handle* handle) {
208   // We automatically set the `lib_def` option for all instantiations, if the
209   // caller doesn't set this option explicitly.
210   if (!options.lib_def && lib_def_) {
211     InstantiateOptions options_copy = options;
212     options_copy.lib_def = lib_def_;
213     return base_flr_->Instantiate(function_name, attrs, options_copy, handle);
214   } else {
215     return base_flr_->Instantiate(function_name, attrs, options, handle);
216   }
217 }
218 
// Forwards handle release to the base runtime.
Status FunctionLibraryRuntimeOverlay::ReleaseHandle(Handle handle) {
  return base_flr_->ReleaseHandle(handle);
}
222 
// Forwards the body lookup to the base runtime.
const FunctionBody* FunctionLibraryRuntimeOverlay::GetFunctionBody(Handle h) {
  return base_flr_->GetFunctionBody(h);
}
226 
// Forwards the return-type query to the base runtime.
Status FunctionLibraryRuntimeOverlay::GetRetTypes(Handle h,
                                                  DataTypeVector* ret_types) {
  return base_flr_->GetRetTypes(h, ret_types);
}
231 
// Forwards execution (tensor-vector form) to the base runtime.
void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
                                        gtl::ArraySlice<Tensor> args,
                                        std::vector<Tensor>* rets,
                                        DoneCallback done) {
  base_flr_->Run(opts, handle, args, rets, std::move(done));
}
238 
// Forwards execution (call-frame form) to the base runtime.
void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
                                        CallFrameInterface* call_frame,
                                        DoneCallback done) {
  base_flr_->Run(opts, handle, call_frame, std::move(done));
}
244 
// Forwards synchronous execution (tensor-vector form) to the base runtime.
Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
                                              gtl::ArraySlice<Tensor> args,
                                              std::vector<Tensor>* rets) {
  return base_flr_->RunSync(std::move(opts), handle, args, rets);
}
250 
// Forwards synchronous execution (call-frame form) to the base runtime.
Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
                                              CallFrameInterface* call_frame) {
  return base_flr_->RunSync(std::move(opts), handle, call_frame);
}
255 
// Kernel creation through an overlay is intentionally unsupported; always
// returns an Internal error.
Status FunctionLibraryRuntimeOverlay::CreateKernel(
    const std::shared_ptr<const NodeProperties>&, OpKernel**) {
  // We don't have access to base_lib_def_ in base function library runtime (aka
  // FunctionLibraryRuntimeImpl), so to make sure we do not create a kernel with
  // the wrong lib_def we just disable creation of new kernels through overlays.
  //
  // When we call Instantiate from the base runtime with the lib_def option,
  // the base runtime implementation is responsible for correctly passing it
  // through to all kernel constructions.
  return errors::Internal(
      "Overlay function library runtime doesn't support kernel creation.");
}
268 
IsStateful(const string & function_name) const269 bool FunctionLibraryRuntimeOverlay::IsStateful(
270     const string& function_name) const {
271   // Important: we do not forward lookup to the base FLR.
272   const OpDef* op_def;
273   const Status s = lib_def_->LookUpOpDef(function_name, &op_def);
274   return s.ok() && op_def->is_stateful();
275 }
276 
// Environment is always the base runtime's.
Env* FunctionLibraryRuntimeOverlay::env() { return base_flr_->env(); }
278 
// Session configuration is always the base runtime's.
const ConfigProto* const FunctionLibraryRuntimeOverlay::config_proto() {
  return base_flr_->config_proto();
}
282 
// Device is always the base runtime's.
Device* FunctionLibraryRuntimeOverlay::device() { return base_flr_->device(); }
284 
// Const overload; same device as the base runtime.
const Device* FunctionLibraryRuntimeOverlay::device() const {
  return base_flr_->device();
}
288 
// Closure runner is always the base runtime's.
std::function<void(std::function<void()>)>*
FunctionLibraryRuntimeOverlay::runner() {
  return base_flr_->runner();
}
293 
// Device manager is always the base runtime's.
const DeviceMgr* FunctionLibraryRuntimeOverlay::device_mgr() const {
  return base_flr_->device_mgr();
}
297 
298 const FunctionLibraryDefinition*
GetFunctionLibraryDefinition() const299 FunctionLibraryRuntimeOverlay::GetFunctionLibraryDefinition() const {
300   return lib_def_ ? lib_def_ : base_flr_->GetFunctionLibraryDefinition();
301 }
302 
// Forwards the debug-string query to the base runtime.
string FunctionLibraryRuntimeOverlay::DebugString(Handle handle) {
  return base_flr_->DebugString(handle);
}
306 
// GraphDef version is always the base runtime's.
int FunctionLibraryRuntimeOverlay::graph_def_version() const {
  return base_flr_->graph_def_version();
}
310 
// Clones the base runtime; the clone deliberately drops the overlay's
// library-definition override (see NOTE below).
Status FunctionLibraryRuntimeOverlay::Clone(
    std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
    FunctionLibraryRuntime** out_flr, bool skip_flib_def) {
  // NOTE(ezhulenev): The cloned FunctionLibraryRuntime will be missing the
  // FunctionLibraryDefinition override, but that's ok because we anyway do not
  // copy / clone instantiated items from the base FLR.
  return base_flr_->Clone(out_lib_def, out_pflr, out_flr, skip_flib_def);
}
320 
class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
 public:
  FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env,
                             const ConfigProto* config, Device* device,
                             int graph_def_version,
                             const FunctionLibraryDefinition* lib_def,
                             thread::ThreadPool* default_thread_pool,
                             const OptimizerOptions& optimizer_options,
                             const SessionMetadata* session_metadata,
                             ProcessFunctionLibraryRuntime* parent);

  ~FunctionLibraryRuntimeImpl() override;

  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle handle) override;

  Status GetRetTypes(Handle handle, DataTypeVector* ret_types) override;

  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      OpKernel** kernel) override;

  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;
  void Run(const Options& opts, Handle handle, CallFrameInterface* frame,
           DoneCallback done) override;
  Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
                 std::vector<Tensor>* rets) override;
  Status RunSync(Options opts, Handle handle,
                 CallFrameInterface* call_frame) override;

  bool IsStateful(const string& function) const override;

  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override {
    return base_lib_def_;
  }

  Device* device() override { return device_; }
  const Device* device() const override { return device_; }

  std::function<void(std::function<void()>)>* runner() override {
    return &default_runner_;
  }

  const DeviceMgr* device_mgr() const override { return device_mgr_; }
  Env* env() override { return env_; }
  const ConfigProto* const config_proto() override { return config_; }
  int graph_def_version() const override { return graph_def_version_; }

  string DebugString(Handle h) override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr,
               bool skip_flib_def = false) override;

 private:
  typedef FunctionLibraryRuntimeImpl ME;

  // None of the raw pointers below are owned (the destructor deletes none
  // of them; it only releases items_).
  const DeviceMgr* const device_mgr_;
  Device* const device_;
  Env* const env_;
  const ConfigProto* const config_;
  const int graph_def_version_;
  const FunctionLibraryDefinition* const base_lib_def_;
  GraphOptimizer optimizer_;
  const SessionMetadata* const session_metadata_;
  Executor::Args::Runner default_runner_;
  const string device_name_;

  // Helpers bound in the constructor: op-signature lookup against
  // base_lib_def_, and kernel creation via the two-argument CreateKernel.
  std::function<Status(const string&, const OpDef**)> get_func_sig_;
  std::function<Status(const std::shared_ptr<const NodeProperties>&,
                       OpKernel**)>
      create_kernel_;

  // Guards next_handle_ and items_.
  mutable mutex mu_;

  int next_handle_ TF_GUARDED_BY(mu_);

  // The instantiated and transformed function is encoded as a Graph
  // object, and an executor is created for the graph.
  struct Item {
    uint64 instantiation_counter = 0;
    std::unique_ptr<const Graph> graph = nullptr;
    const FunctionLibraryDefinition* lib_def = nullptr;  // Not owned.
    FunctionBody* func_graph = nullptr;
    Executor* exec = nullptr;
    FunctionLibraryRuntimeOverlay* overlay_flr = nullptr;
    string executor_type;

    // Item owns func_graph, exec and overlay_flr.
    ~Item() {
      delete this->func_graph;
      delete this->exec;
      delete this->overlay_flr;
    }
  };
  // Held behind a unique_ptr so the destructor can release it explicitly
  // before member teardown; see ~FunctionLibraryRuntimeImpl.
  std::unique_ptr<std::unordered_map<Handle, std::unique_ptr<Item>>> items_
      TF_GUARDED_BY(mu_);

  ProcessFunctionLibraryRuntime* parent_ = nullptr;  // not owned.

  // Overloads the CreateKernel method, providing a FunctionLibraryRuntime
  // to use for kernel creation and execution. In particular, this method can
  // accept a FunctionLibraryRuntimeOverlay that overlays a different
  // FunctionLibraryDefinition.
  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      FunctionLibraryRuntime* flr, OpKernel** kernel);
  Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
                           const FunctionLibraryDefinition* lib_def,
                           std::unique_ptr<FunctionBody>* fbody);
  Status CreateItem(Item** item);
  Status GetOrCreateItem(LocalHandle local_handle, Item** item);
  Status InstantiateSymbolicGradient(const NameAttrList& func,
                                     const FunctionLibraryDefinition* lib_def,
                                     std::unique_ptr<FunctionBody>* g_body);
  bool IsLocalTarget(const InstantiateOptions& options) const;
  AttrValueMap FixAttrs(const AttrSlice& attrs);
  void RunRemote(const Options& opts, Handle handle,
                 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                 Item* item, DoneCallback done);

  Status PrepareRunSync(
      Handle handle, Options* run_opts, Item** out_item,
      std::unique_ptr<PrivateIntraProcessRendezvous>* out_rendezvous);

  void ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options& run_opts,
                               CallFrameInterface* frame,
                               Executor::Args* exec_args);

  TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
};
457 
FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
    const DeviceMgr* dmgr, Env* env, const ConfigProto* config, Device* device,
    int graph_def_version, const FunctionLibraryDefinition* lib_def,
    thread::ThreadPool* default_thread_pool,
    const OptimizerOptions& optimizer_options,
    const SessionMetadata* session_metadata,
    ProcessFunctionLibraryRuntime* parent)
    : device_mgr_(dmgr),
      device_(device),
      env_(env),
      config_(config),
      graph_def_version_(graph_def_version),
      base_lib_def_(lib_def),
      optimizer_(optimizer_options),
      session_metadata_(session_metadata),
      default_runner_(nullptr),
      // A device-less FLR uses the process-level default device name.
      device_name_(device_ == nullptr
                       ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
                       : device_->name()),
      next_handle_(0),
      items_(new std::unordered_map<Handle, std::unique_ptr<Item>>),
      parent_(parent) {
  // Op-signature lookup used when instantiating against the base library.
  get_func_sig_ = [this](const string& op, const OpDef** sig) {
    return base_lib_def_->LookUpOpDef(op, sig);
  };
  // Kernel factory that binds `this` as the FunctionLibraryRuntime.
  create_kernel_ = [this](const std::shared_ptr<const NodeProperties>& props,
                          OpKernel** kernel) {
    return CreateKernel(props, kernel);
  };
  // Choose the thread pool backing default_runner_: prefer the device's own
  // pool, then the provided default pool; with neither, default_runner_
  // remains null.
  thread::ThreadPool* pool = nullptr;
  if (device_ != nullptr) {
    pool = device_->tensorflow_device_thread_pool();
  }
  if (pool == nullptr) {
    pool = default_thread_pool;
  }
  if (pool != nullptr) {
    default_runner_ = [pool](Executor::Args::Closure c) {
      pool->Schedule(std::move(c));
    };
  }
}
500 
FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
  // Deleting the items_ list will delete all the function handles registered in
  // this object. A function may contain a few sub-functions which have also
  // been registered in this object. Deleting the parent function will call
  // ReleaseHandle in this class again for each of the sub-functions. These
  // circular calls may cause a segfault since items_ may have already been
  // partially deleted when releasing handles of sub-functions. Explicitly
  // release items_ here and check it in ReleaseHandle to avoid this.
  items_.reset();
}
511 
512 // An asynchronous op kernel which executes an instantiated function
513 // defined in a library.
class CallOp : public AsyncOpKernel {
 public:
  // `handle` identifies an already-instantiated function in the runtime
  // obtained from the executing kernel context.
  CallOp(FunctionLibraryRuntime::Handle handle, OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx), handle_(handle) {}

  ~CallOp() override {
    // TODO(iga): Release the cached handle_
  }

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FunctionLibraryRuntime* lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library is provided."),
                      done);
    // Propagate the enclosing kernel's execution state into the nested call.
    FunctionLibraryRuntime::Options opts;
    opts.rendezvous = ctx->rendezvous();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.step_container = ctx->step_container();
    opts.stats_collector = ctx->stats_collector();
    opts.runner = ctx->runner();
    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
    opts.collective_executor = ctx->collective_executor();
    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      args.push_back(ctx->input(i));
    }
    // Heap-allocated so the asynchronous completion callback below can fill
    // it in and then delete it.
    std::vector<Tensor>* rets = new std::vector<Tensor>;
    profiler::TraceMe trace_me(
        [&] {
          return absl::StrCat("CallOp #parent_step_id=", ctx->step_id(),
                              ",function_step_id=", opts.step_id, "#");
        },
        /*level=*/2);
    lib->Run(opts, handle_, args, rets,
             [ctx, done, rets](const Status& status) {
               if (!status.ok()) {
                 ctx->SetStatus(status);
               } else {
                 const int ret_size = static_cast<int>(rets->size());
                 // Crashes the process if the function produced a different
                 // number of outputs than this kernel declares.
                 CHECK_EQ(ret_size, ctx->num_outputs());
                 for (int i = 0; i < ret_size; ++i) {
                   ctx->set_output(i, (*rets)[i]);
                 }
               }
               delete rets;
               done();
             });
  }

 private:
  FunctionLibraryRuntime::Handle handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
};
569 
// Returns the instantiated body for `h`, or nullptr (after logging an error)
// if the handle is not registered on this FLR's device.
const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
  if (local_handle == kInvalidLocalHandle) {
    LOG(ERROR) << "Could not find Handle: " << h
               << " on device: " << device_name_;
    return nullptr;
  }

  tf_shared_lock l(mu_);
  auto iter = items_->find(local_handle);
  CHECK(iter != items_->end());  // Crash OK: local handle must have an item.
  return iter->second->func_graph;
}
583 
GetRetTypes(Handle h,DataTypeVector * ret_types)584 Status FunctionLibraryRuntimeImpl::GetRetTypes(Handle h,
585                                                DataTypeVector* ret_types) {
586   if (parent_->IsMultiDevice(h)) {
587     return parent_->GetRetTypes(h, ret_types);
588   }
589   LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
590   if (local_handle == kInvalidLocalHandle) {
591     return errors::InvalidArgument("Handle ", h, " not found.");
592   }
593   const FunctionBody* fbody = GetFunctionBody(h);
594   *ret_types = fbody->ret_types;
595   return Status::OK();
596 }
597 
// Public entry point: creates a kernel using this runtime itself as the FLR.
Status FunctionLibraryRuntimeImpl::CreateKernel(
    const std::shared_ptr<const NodeProperties>& props, OpKernel** kernel) {
  return CreateKernel(props, this, kernel);
}
602 
// Creates a kernel for `props`: via the custom kernel creator when one
// applies, directly for primitive ops, or as a CallOp wrapping a freshly
// instantiated function otherwise.
Status FunctionLibraryRuntimeImpl::CreateKernel(
    const std::shared_ptr<const NodeProperties>& props,
    FunctionLibraryRuntime* flr, OpKernel** kernel) {
  // If a custom kernel creator is given, try that.
  Status s;
  const CustomKernelCreator* custom_kernel_creator =
      GetDefaultCustomKernelCreator();
  if (custom_kernel_creator &&
      custom_kernel_creator->CanCreateKernel(*this, props)) {
    std::unique_ptr<OpKernel> ret;
    s = custom_kernel_creator->CreateKernel(this, props, &ret);
    if (s.ok()) {
      *kernel = ret.release();
    } else {
      VLOG(2) << "Custom creator error: " << s;
    }
    return s;
  }

  const FunctionLibraryDefinition* lib_def =
      flr->GetFunctionLibraryDefinition();
  if (lib_def->Find(props->node_def.op()) == nullptr) {
    // A primitive operation. Creates the registered kernel.
    return CreateNonCachedKernel(device_, flr, props, graph_def_version_,
                                 kernel);
  }

  // Try to instantiate this function for the func/attr. Maybe it's
  // cached already.
  InstantiateOptions options;
  if (lib_def != base_lib_def_) {
    options.lib_def = lib_def;
  }
  Handle handle;
  TF_RETURN_IF_ERROR(Instantiate(props->node_def.op(),
                                 AttrSlice(&props->node_def.attr()), options,
                                 &handle));

  const FunctionBody* fbody = GetFunctionBody(handle);
  CHECK_NOTNULL(fbody);

  // TODO(zhifengc): For now, we assume int32 and resources are always on host
  // memory and other types are always on device memory. We should do type
  // inference over function body to derive the correct input/output memory
  // types.
  MemoryTypeVector input_memory_types;
  for (const auto& t : fbody->arg_types) {
    input_memory_types.push_back(MTypeFromDType(t));
  }
  MemoryTypeVector output_memory_types;
  for (const auto& t : fbody->ret_types) {
    output_memory_types.push_back(MTypeFromDType(t));
  }

  // Constructs a CallOp kernel for running the instantiated function.
  auto device_type = DeviceType(device_->attributes().device_type());
  // NOTE(review): `new_props` is built from the function signature but
  // `props` is what gets passed to OpKernelConstruction below, leaving
  // `new_props` apparently unused — confirm whether `new_props` should be
  // passed instead.
  auto new_props = std::make_shared<NodeProperties>(
      &fbody->fdef.signature(), props->node_def, fbody->arg_types,
      fbody->ret_types);
  OpKernelConstruction construction(
      device_type, device_, device_->GetAllocator(AllocatorAttributes()), flr,
      device_->resource_manager(), props, input_memory_types,
      output_memory_types, graph_def_version_, &s);
  if (s.ok()) {
    *kernel = new CallOp(handle, &construction);
  }
  return s;
}
671 
FunctionDefToBody(const FunctionDef & fdef,AttrSlice attrs,const FunctionLibraryDefinition * lib_def,std::unique_ptr<FunctionBody> * fbody)672 Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
673     const FunctionDef& fdef, AttrSlice attrs,
674     const FunctionLibraryDefinition* lib_def,
675     std::unique_ptr<FunctionBody>* fbody) {
676   if (lib_def == base_lib_def_) {
677     return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig_, fbody);
678   } else {
679     auto get_func_sig = [lib_def](const string& op, const OpDef** sig) {
680       return lib_def->LookUpOpDef(op, sig);
681     };
682     return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
683   }
684 }
685 
// Builds the body of the symbolic gradient of `func`: from a registered op
// gradient creator when `func` is a primitive op, or by differentiating the
// instantiated body when `func` is a user-defined function.
Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
    const NameAttrList& func, const FunctionLibraryDefinition* lib_def,
    std::unique_ptr<FunctionBody>* g_body) {
  const FunctionDef* fdef = lib_def->Find(func.name());
  if (fdef == nullptr) {
    // f is a primitive op.
    gradient::Creator creator;
    TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
    if (creator == nullptr) {
      return errors::InvalidArgument("No gradient is defined for ",
                                     func.name());
    }
    FunctionDef grad_fdef;
    // TODO(josh11b): Should filter out the attrs from func that aren't used
    // by the gradient function.
    TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
    TF_RETURN_IF_ERROR(
        FunctionDefToBody(grad_fdef, AttrSlice(&func.attr()), lib_def, g_body));
  } else {
    // f is a user-defined function.
    InstantiateOptions options;
    if (lib_def != base_lib_def_) {
      options.lib_def = lib_def;
    }
    Handle f_handle;
    TF_RETURN_IF_ERROR(
        Instantiate(func.name(), AttrSlice(&func.attr()), options, &f_handle));
    const FunctionBody* f_body = GetFunctionBody(f_handle);
    CHECK_NOTNULL(f_body);
    // Differentiate the instantiated body to obtain the gradient body.
    *g_body = SymbolicGradient(*f_body);
  }
  return Status::OK();
}
719 
IsLocalTarget(const InstantiateOptions & options) const720 bool FunctionLibraryRuntimeImpl::IsLocalTarget(
721     const InstantiateOptions& options) const {
722   if (device_ == nullptr) return true;
723   if (options.target.empty()) return true;
724   if (options.is_multi_device_function) return false;
725   Device* target_device;
726   if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) {
727     VLOG(1) << "Not instantiating function in FLR because failed to "
728             << "find device " << options.target << " in device manager";
729     return false;
730   }
731   if (target_device != device_) {
732     VLOG(1) << "Not instantiating function in FLR because target device "
733             << options.target
734             << " is different from FLR's device: " << device_->DebugString();
735     return false;
736   }
737   return true;
738 }
739 
// Instantiates `function_name` with `attrs`, returning a handle in `*handle`.
//
// Non-local targets are delegated to `parent_`. For local targets the
// canonical key is computed with the local device name, an existing handle is
// reused (bumping its instantiation counter), and otherwise a FunctionBody is
// built and registered as a new Item. The body construction happens between
// two separately-locked critical sections, so a concurrent instantiation of
// the same key is re-checked before inserting.
Status FunctionLibraryRuntimeImpl::Instantiate(
    const string& function_name, AttrSlice attrs,
    const InstantiateOptions& options, Handle* handle) {
  if (!IsLocalTarget(options)) {
    return parent_->Instantiate(function_name, attrs, options, handle);
  }

  // Since this is a local target, ensure that the local `device_name_` appears
  // in the canonical key.
  InstantiateOptions options_copy(options);
  options_copy.target = device_name_;
  const string key = Canonicalize(function_name, attrs, options_copy);

  {
    // Fast path: the function was already instantiated under this key.
    mutex_lock l(mu_);
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      FunctionLibraryRuntime::LocalHandle handle_on_device =
          parent_->GetHandleOnDevice(device_name_, *handle);
      if (handle_on_device == kInvalidLocalHandle) {
        return errors::Internal("LocalHandle not found for handle ", *handle,
                                ".");
      }
      auto item_handle = items_->find(handle_on_device);
      if (item_handle == items_->end()) {
        return errors::Internal("LocalHandle ", handle_on_device,
                                " for handle ", *handle,
                                " not found in items.");
      }
      ++item_handle->second->instantiation_counter;
      return Status::OK();
    }
  }

  // Build the function body outside the lock; this may recurse back into
  // Instantiate() (e.g. for gradients).
  const FunctionLibraryDefinition* lib_def =
      options.lib_def ? options.lib_def : base_lib_def_;
  std::unique_ptr<FunctionBody> fbody;
  if (function_name == kGradientOp) {
    const AttrValue* f = attrs.Find(kFuncAttr);
    if (f == nullptr) {
      return errors::InvalidArgument("SymbolicGradient is missing attr: f");
    }
    const auto& func = f->func();
    if (func.name() == kGradientOp) {
      return errors::InvalidArgument("Can't take gradient of SymbolicGradient");
    }
    // Prefer a registered custom gradient function over the symbolic one.
    const string grad = lib_def->FindGradient(func.name());
    if (!grad.empty()) {
      return Instantiate(grad, AttrSlice(&func.attr()), options, handle);
    }
    TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, lib_def, &fbody));
  } else {
    const FunctionDef* fdef = lib_def->Find(function_name);
    if (fdef == nullptr) {
      return errors::NotFound("Function ", function_name, " is not defined.");
    }
    TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody));
  }

  LocalHandle local_handle;
  {
    // Re-check under the lock: another thread may have registered the same
    // key while we were building the body.
    mutex_lock l(mu_);
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      local_handle = parent_->GetHandleOnDevice(device_name_, *handle);
      ++(*items_)[local_handle]->instantiation_counter;
    } else {
      *handle = parent_->AddHandle(key, device_name_, next_handle_);
      Item* item = new Item;
      item->func_graph = fbody.release();
      item->instantiation_counter = 1;
      item->executor_type = ExecutorType(options, attrs);
      if (options.lib_def) {
        item->overlay_flr =
            new FunctionLibraryRuntimeOverlay(this, options.lib_def);
      }
      local_handle = next_handle_++;
      items_->emplace(local_handle, std::unique_ptr<Item>(item));
    }
  }

  // Optionally build the executor/kernels now instead of on first Run().
  if (options.create_kernels_eagerly) {
    Item* item;
    TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, &item));
  }

  return Status::OK();
}
828 
// Decrements the instantiation count for `handle` and, when it reaches zero,
// removes the corresponding Item. Handles that are not local to this device
// are delegated to `parent_`. The Item is destroyed *after* mu_ is released
// (see comment below) to avoid re-entrant deadlock.
Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
  LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
  if (h == kInvalidLocalHandle) {
    return parent_->ReleaseHandle(handle);
  }
  std::unique_ptr<Item> item_to_delete;
  Status parent_status;
  {
    mutex_lock l(mu_);
    // Return directly if all items has already been released.
    if (items_ == nullptr) return Status::OK();

    auto it = items_->find(h);
    if (it == items_->end()) {
      return errors::Internal(
          "Inconsistent FunctionLibraryRuntime. Expected to find an item for "
          "handle ",
          h, " but found none");
    }
    std::unique_ptr<Item>& item = it->second;
    --item->instantiation_counter;
    if (item->instantiation_counter == 0) {
      // We don't simply erase h's item because that would trigger
      // item destruction while holding mu_. Item destruction can
      // trigger graph destruction. If the graph contains kernels like
      // CallOp or PartitionCallOp, their destructors will release cached
      // function handles, resulting in deadlock here.
      item_to_delete = std::move(item);
      items_->erase(h);
      parent_status = parent_->RemoveHandle(handle);
    }
  }
  // `item_to_delete` (if set) is destroyed here, outside the lock.
  return parent_status;
}
863 
864 namespace {
865 // Removes all stateless nodes that do not contribute to a return
866 // value from the function body. Unlike `RemoveDeadNodes()`, which is
867 // triggered by `OptimizerOptions.do_function_inlining`, this pass
868 // ignores the SINK node, from which (by definition) all nodes are
869 // reverse reachable, and preserves all nodes that are reachable from
870 // control output nodes.
871 //
872 // TODO(ezhulenev, skyewm): Function body should not have special treatment of
873 // stateful ops, graph should encode nodes that must execute with `control_ret`
874 // and `control_output`.
PruneFunctionBody(const FunctionDef & fdef,Graph * g)875 void PruneFunctionBody(const FunctionDef& fdef, Graph* g) {
876   VLOG(2) << "Pruning function body: function_name=" << fdef.signature().name();
877 
878   // `control_ret` nodes must be always executed.
879   std::unordered_set<StringPiece, StringPieceHasher> control_ret_nodes;
880   for (const auto& control_ret : fdef.control_ret()) {
881     control_ret_nodes.insert(control_ret.second);
882   }
883 
884   std::unordered_set<const Node*> nodes;
885   for (auto n : g->nodes()) {
886     // NOTE(mrry): "_Retval" nodes are stateful, and so will be added
887     // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we
888     // specifically exclude them as seeds, to avoid unconditionally executing
889     // unused argument nodes (e.g. in a function like `lambda x, y: y`).
890     // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is
891     // still needed. It would be preferable to prune entire loops and/or
892     // conditionals if they are not used in the graph.
893     if (n->IsControlFlow() ||
894         (n->op_def().is_stateful() && n->type_string() != kArgOp) ||
895         (control_ret_nodes.find(n->name()) != control_ret_nodes.end())) {
896       nodes.insert(n);
897     }
898   }
899   bool changed = PruneForReverseReachability(g, std::move(nodes));
900   if (changed) {
901     FixupSourceAndSinkEdges(g);
902   }
903 }
904 }  // namespace
905 
// Builds the executable state for `*item`: copies the function graph, prunes
// and optimizes it, and creates an Executor. The expensive work happens
// without holding mu_ (kernel creation can call back into this library); the
// result is installed under the lock only if no other thread won the race.
Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
  const FunctionBody* fbody;
  FunctionLibraryRuntime* flr;
  string executor_type;
  {
    // Snapshot the fields we need; the Item itself lives in items_ and may be
    // read concurrently.
    tf_shared_lock l(mu_);
    fbody = (*item)->func_graph;
    flr = (*item)->overlay_flr
              ? static_cast<FunctionLibraryRuntime*>((*item)->overlay_flr)
              : static_cast<FunctionLibraryRuntime*>(this);
    executor_type = (*item)->executor_type;
  }
  const FunctionLibraryDefinition* lib_def =
      flr->GetFunctionLibraryDefinition();
  std::unique_ptr<Graph> g(new Graph(lib_def));
  CopyGraph(*fbody->graph, g.get());

  PruneFunctionBody(fbody->fdef, g.get());
  optimizer_.Optimize(this, env(), device(), &g, /*shape_map=*/nullptr);
  TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()),
                                       device()->name(), g.get()));

  // Creates an executor based on the g. This must be done without
  // holding mu_ because create_kernel_ calls back into the library.
  LocalExecutorParams params;
  params.device = device_;
  params.function_library = flr;
  if (flr == this) {
    params.create_kernel = create_kernel_;
  } else {
    // With an overlay FLR, route kernel creation through it so nested
    // functions resolve against the overlay library.
    params.create_kernel =
        [this, flr](const std::shared_ptr<const NodeProperties>& props,
                    OpKernel** kernel) {
          return CreateKernel(props, flr, kernel);
        };
  }
  params.delete_kernel = [](OpKernel* kernel) {
    DeleteNonCachedKernel(kernel);
  };
  params.session_metadata = session_metadata_;
  std::unique_ptr<Executor> exec;
  TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, *g, &exec));
  {
    // Guard item since it is already inserted in items_.
    mutex_lock l(mu_);
    // Another thread may have created the executor first; keep theirs.
    if ((*item)->exec == nullptr) {
      (*item)->graph = std::move(g);
      (*item)->exec = exec.release();
    }
  }
  return Status::OK();
}
958 
GetOrCreateItem(LocalHandle local_handle,Item ** item)959 Status FunctionLibraryRuntimeImpl::GetOrCreateItem(LocalHandle local_handle,
960                                                    Item** item) {
961   {
962     tf_shared_lock l(mu_);
963     auto iter = items_->find(local_handle);
964     if (iter == items_->end()) {
965       return errors::Internal("Local function handle ", local_handle,
966                               " is not valid. Likely an internal error.");
967     }
968     *item = iter->second.get();
969     if ((*item)->exec != nullptr) {
970       return Status::OK();
971     }
972   }
973   // NOTE: We need to call CreateItem out of mu_ because creating an
974   // executor needs to call CreateKernel.
975   return CreateItem(item);
976 }
977 
ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options & run_opts,CallFrameInterface * frame,Executor::Args * exec_args)978 void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions(
979     const FunctionLibraryRuntime::Options& run_opts, CallFrameInterface* frame,
980     Executor::Args* exec_args) {
981   // Inherit the step_id from the caller.
982   exec_args->step_id = run_opts.step_id;
983   exec_args->rendezvous = run_opts.rendezvous;
984   exec_args->stats_collector = run_opts.stats_collector;
985   exec_args->cancellation_manager = run_opts.cancellation_manager;
986   exec_args->step_container = run_opts.step_container;
987   if (run_opts.runner) {
988     exec_args->runner = *run_opts.runner;
989   } else {
990     exec_args->runner = default_runner_;
991   }
992   exec_args->collective_executor = run_opts.collective_executor;
993   exec_args->call_frame = frame;
994   exec_args->run_all_kernels_inline = run_opts.run_all_kernels_inline;
995 }
996 
// Runs the instantiated function `handle` on behalf of a remote caller:
// receives the argument tensors from `opts.source_device` via the rendezvous,
// executes the function locally, and sends the return values back. `frame`,
// `exec_args` and `remote_args` are heap-allocated because they must outlive
// this call and are deleted inside the nested async callbacks below.
void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
                                           gtl::ArraySlice<Tensor> args,
                                           std::vector<Tensor>* rets,
                                           Item* item, DoneCallback done) {
  string target_device = parent_->GetDeviceName(handle);
  string source_device = opts.source_device;
  RendezvousInterface* rendezvous = opts.rendezvous;
  DeviceContext* device_context;
  Status s = parent_->GetDeviceContext(target_device, &device_context);
  if (!s.ok()) {
    done(s);
    return;
  }
  // Incarnations disambiguate rendezvous keys across device restarts.
  int64 src_incarnation, target_incarnation;
  s = parent_->GetDeviceIncarnation(source_device, &src_incarnation);
  s.Update(parent_->GetDeviceIncarnation(target_device, &target_incarnation));
  if (!s.ok()) {
    done(s);
    return;
  }

  const FunctionBody* fbody = GetFunctionBody(handle);
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  Executor::Args* exec_args = new Executor::Args;
  ExecutorArgsFromOptions(opts, frame, exec_args);

  std::vector<AllocatorAttributes> args_alloc_attrs, rets_alloc_attrs;
  args_alloc_attrs.reserve(fbody->arg_types.size());
  rets_alloc_attrs.reserve(fbody->ret_types.size());
  // Note: Functions assume that int32's are always on host memory.
  for (const auto& arg_type : fbody->arg_types) {
    AllocatorAttributes arg_alloc_attrs;
    if (MTypeFromDType(arg_type) == HOST_MEMORY) {
      arg_alloc_attrs.set_on_host(true);
    }
    args_alloc_attrs.push_back(arg_alloc_attrs);
  }
  for (const auto& ret_type : fbody->ret_types) {
    AllocatorAttributes ret_alloc_attrs;
    if (MTypeFromDType(ret_type) == HOST_MEMORY) {
      ret_alloc_attrs.set_on_host(true);
    }
    rets_alloc_attrs.push_back(ret_alloc_attrs);
  }

  bool allow_dead_tensors = opts.allow_dead_tensors;

  // The ProcFLR sends the arguments to the function from the source_device to
  // the target_device. So here we receive those arguments. Similarly, when the
  // computation is done and stored in *rets, we send the return values back
  // to the source_device (caller) so that the ProcFLR can receive them later.
  std::vector<Tensor>* remote_args = new std::vector<Tensor>;
  ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
      source_device, target_device, "arg_", src_incarnation, args.size(),
      device_context, args_alloc_attrs, rendezvous, remote_args,
      [frame, remote_args, item, source_device, target_device,
       target_incarnation, rendezvous, device_context, rets, done, exec_args,
       rets_alloc_attrs, allow_dead_tensors](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = frame->SetArgs(*remote_args);
        }
        if (!s.ok()) {
          // Receiving (or staging) the arguments failed: clean up all
          // heap-allocated state before reporting the error.
          delete frame;
          delete remote_args;
          delete exec_args;
          done(s);
          return;
        }
        item->exec->RunAsync(
            *exec_args,
            [frame, rets, done, source_device, target_device,
             target_incarnation, rendezvous, device_context, remote_args,
             rets_alloc_attrs, allow_dead_tensors](const Status& status) {
              Status s = status;
              if (s.ok()) {
                s = frame->ConsumeRetvals(rets, allow_dead_tensors);
              }
              delete frame;
              if (!s.ok()) {
                delete remote_args;
                done(s);
                return;
              }
              // Ship the results back to the caller's device.
              s = ProcessFunctionLibraryRuntime::SendTensors(
                  target_device, source_device, "ret_", target_incarnation,
                  *rets, device_context, rets_alloc_attrs, rendezvous);
              delete remote_args;
              done(s);
            });
        // Safe to delete here: RunAsync copied what it needs from *exec_args.
        delete exec_args;
      });
}
1091 
// Asynchronously runs the instantiated function `handle` with tensor-vector
// arguments, invoking `done` with the final status. Non-local handles are
// delegated to `parent_`; remote execution requests go through RunRemote().
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
                                     gtl::ArraySlice<Tensor> args,
                                     std::vector<Tensor>* rets,
                                     DoneCallback done) {
  if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
    done(errors::Cancelled("Function was cancelled before it was started"));
    return;
  }
  Options run_opts = opts;
  if (opts.create_rendezvous) {
    // Own a private rendezvous for this call; wrap `done` so it is destroyed
    // exactly once, after the function completes.
    auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_);
    run_opts.rendezvous = rendezvous;
    run_opts.create_rendezvous = false;
    done = [done = std::move(done), rendezvous](const Status& status) mutable {
      delete rendezvous;
      done(status);
    };
  }

  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
  if (local_handle == kInvalidLocalHandle) {
    parent_->Run(run_opts, handle, args, rets, done);
    return;
  }

  if (run_opts.runner == nullptr) {
    run_opts.runner = &default_runner_;
  }
  DCHECK(run_opts.runner != nullptr);

  Item* item = nullptr;
  Status s = GetOrCreateItem(local_handle, &item);
  if (!s.ok()) {
    done(s);
    return;
  }

  if (run_opts.remote_execution) {
    // NOTE(mrry): `RunRemote()` will set `exec_args->call_frame` for us.
    RunRemote(run_opts, handle, args, rets, item, std::move(done));
    return;
  }

  // Local execution: marshal args into a call frame owned by the callback.
  const FunctionBody* fbody = GetFunctionBody(handle);
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  s = frame->SetArgs(args);
  if (!s.ok()) {
    delete frame;
    done(s);
    return;
  }

  profiler::TraceMeProducer activity(
      // To TraceMeConsumers in ExecutorState::Process/Finish.
      [&opts] {
        return profiler::TraceMeEncode("FunctionRun", {{"id", opts.step_id}});
      },
      profiler::ContextType::kTfExecutor, opts.step_id,
      profiler::TraceMeLevel::kInfo);

  Executor::Args exec_args;
  ExecutorArgsFromOptions(run_opts, frame, &exec_args);

  bool allow_dead_tensors = run_opts.allow_dead_tensors;
  item->exec->RunAsync(
      // Executor args
      exec_args,
      // Done callback.
      [frame, rets, done, allow_dead_tensors](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = frame->ConsumeRetvals(rets, allow_dead_tensors);
        }
        delete frame;
        done(s);
      });
}
1170 
// Asynchronously runs the instantiated function `handle` using a
// caller-provided call frame. Non-local handles are delegated to `parent_`;
// remote execution is unsupported on this interface (see note below).
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
                                     CallFrameInterface* frame,
                                     DoneCallback done) {
  if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
    done(errors::Cancelled(""));
    return;
  }

  Options run_opts = opts;
  if (opts.create_rendezvous) {
    // Own a private rendezvous for this call; wrap `done` so it is destroyed
    // exactly once, after the function completes.
    auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_);
    run_opts.rendezvous = rendezvous;
    run_opts.create_rendezvous = false;
    done = [done = std::move(done), rendezvous](const Status& status) mutable {
      delete rendezvous;
      done(status);
    };
  }

  LocalHandle local_handle = parent_->GetHandleOnDevice(
      device_name_, handle, /*include_multi_device=*/true);
  if (local_handle == kInvalidLocalHandle) {
    parent_->Run(run_opts, handle, frame, done);
    return;
  }

  if (opts.remote_execution) {
    // NOTE(mrry): This bit is only set for a local function when `parent_`
    // calls back into this class, and the current implementation of
    // `ProcessFunctionLibraryRuntime` currently always uses the vector-based
    // `args`/`rets` interface.
    done(errors::Unimplemented("Remote calling with CallFrameInterface"));
    return;
  }

  Item* item = nullptr;
  Status s = GetOrCreateItem(local_handle, &item);
  if (!s.ok()) {
    done(s);
    return;
  }
  if (run_opts.runner == nullptr) {
    run_opts.runner = &default_runner_;
  }
  DCHECK(run_opts.runner != nullptr);

  profiler::TraceMeProducer activity(
      // To TraceMeConsumers in ExecutorState::Process/Finish.
      [&opts] {
        return profiler::TraceMeEncode("FunctionRun", {{"id", opts.step_id}});
      },
      profiler::ContextType::kTfExecutor, opts.step_id,
      profiler::TraceMeLevel::kInfo);

  Executor::Args exec_args;
  ExecutorArgsFromOptions(run_opts, frame, &exec_args);
  item->exec->RunAsync(exec_args, std::move(done));
}
1229 
PrepareRunSync(Handle handle,Options * run_opts,Item ** out_item,std::unique_ptr<PrivateIntraProcessRendezvous> * out_rendezvous)1230 Status FunctionLibraryRuntimeImpl::PrepareRunSync(
1231     Handle handle, Options* run_opts, Item** out_item,
1232     std::unique_ptr<PrivateIntraProcessRendezvous>* out_rendezvous) {
1233   if (run_opts->cancellation_manager &&
1234       run_opts->cancellation_manager->IsCancelled()) {
1235     return errors::Cancelled("");
1236   }
1237 
1238   if (run_opts->remote_execution) {
1239     // NOTE(mrry): This bit is only set for a local function when `parent_`
1240     // calls back into this class, and the current implementation of
1241     // `ProcessFunctionLibraryRuntime` currently always uses the asynchronous
1242     // Run() method.
1243     return errors::Unimplemented("Remote calling with RunSync()");
1244   }
1245 
1246   if (run_opts->create_rendezvous) {
1247     *out_rendezvous =
1248         absl::make_unique<PrivateIntraProcessRendezvous>(device_mgr_);
1249     run_opts->rendezvous = out_rendezvous->get();
1250     run_opts->create_rendezvous = false;
1251   }
1252 
1253   LocalHandle local_handle = parent_->GetHandleOnDevice(
1254       device_name_, handle, /*include_multi_device=*/true);
1255   if (local_handle == kInvalidLocalHandle) {
1256     *out_item = nullptr;
1257     return Status::OK();
1258   }
1259 
1260   TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, out_item));
1261 
1262   if (run_opts->runner == nullptr) {
1263     run_opts->runner = &default_runner_;
1264   }
1265   DCHECK(run_opts->runner != nullptr);
1266 
1267   return Status::OK();
1268 }
1269 
RunSync(Options opts,Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets)1270 Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
1271                                            gtl::ArraySlice<Tensor> args,
1272                                            std::vector<Tensor>* rets) {
1273   Item* item = nullptr;
1274   std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
1275   TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
1276   if (item == nullptr) {
1277     return parent_->RunSync(opts, handle, args, rets);
1278   }
1279 
1280   Executor::Args exec_args;
1281   const FunctionBody* fbody = GetFunctionBody(handle);
1282   FunctionCallFrame frame(fbody->arg_types, fbody->ret_types);
1283   TF_RETURN_IF_ERROR(frame.SetArgs(args));
1284   ExecutorArgsFromOptions(opts, &frame, &exec_args);
1285 
1286   TF_RETURN_IF_ERROR(item->exec->Run(exec_args));
1287   return frame.ConsumeRetvals(rets, opts.allow_dead_tensors);
1288 }
1289 
RunSync(Options opts,Handle handle,CallFrameInterface * call_frame)1290 Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
1291                                            CallFrameInterface* call_frame) {
1292   Item* item = nullptr;
1293   std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
1294   TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
1295   if (item == nullptr) {
1296     return parent_->RunSync(opts, handle, call_frame);
1297   }
1298 
1299   Executor::Args exec_args;
1300   ExecutorArgsFromOptions(opts, call_frame, &exec_args);
1301   return item->exec->Run(exec_args);
1302 }
1303 
IsStateful(const string & func) const1304 bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) const {
1305   const OpDef* op_def;
1306   const Status s = base_lib_def_->LookUpOpDef(func, &op_def);
1307   return s.ok() && op_def->is_stateful();
1308 }
1309 
DebugString(Handle handle)1310 string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
1311   Item* item = nullptr;
1312   LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
1313   Status s = GetOrCreateItem(local_handle, &item);
1314   if (s.ok()) {
1315     if (item->graph) {
1316       return tensorflow::DebugString(item->graph.get());
1317     } else {
1318       return tensorflow::DebugString(item->func_graph->graph);
1319     }
1320   } else {
1321     return s.ToString();
1322   }
1323 }
1324 
Clone(std::unique_ptr<FunctionLibraryDefinition> * out_lib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * out_pflr,FunctionLibraryRuntime ** out_flr,bool skip_flib_def)1325 Status FunctionLibraryRuntimeImpl::Clone(
1326     std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
1327     std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
1328     FunctionLibraryRuntime** out_flr, bool skip_flib_def) {
1329   TF_RETURN_IF_ERROR(parent_->Clone(env_, graph_def_version_,
1330                                     optimizer_.options(), out_lib_def, out_pflr,
1331                                     skip_flib_def));
1332   *out_flr = (*out_pflr)->GetFLR(device_->name());
1333   if (*out_flr != nullptr) {
1334     return Status::OK();
1335   } else {
1336     return errors::Internal("Cloning FunctionLibraryRuntime failed.");
1337   }
1338 }
1339 
1340 namespace {
1341 
1342 struct CustomCreatorSingleton {
1343   mutex mu;
1344   CustomKernelCreator* custom_creator = nullptr;
1345 
Settensorflow::__anon9a8797221111::CustomCreatorSingleton1346   void Set(CustomKernelCreator* cb) {
1347     mutex_lock l(mu);
1348     custom_creator = cb;
1349   }
1350 
Gettensorflow::__anon9a8797221111::CustomCreatorSingleton1351   CustomKernelCreator* Get() {
1352     mutex_lock l(mu);
1353     return custom_creator;
1354   }
1355 };
1356 
GetCustomCreatorSingleton()1357 CustomCreatorSingleton* GetCustomCreatorSingleton() {
1358   static CustomCreatorSingleton* ccs = new CustomCreatorSingleton;
1359   return ccs;
1360 }
1361 
1362 }  // namespace
1363 
GetDefaultCustomKernelCreator()1364 const CustomKernelCreator* GetDefaultCustomKernelCreator() {
1365   return GetCustomCreatorSingleton()->Get();
1366 }
1367 
RegisterDefaultCustomKernelCreator(CustomKernelCreator * c)1368 void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c) {
1369   GetCustomCreatorSingleton()->Set(c);
1370 }
1371 
NewFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,const ConfigProto * config,Device * device,int graph_def_version,const FunctionLibraryDefinition * lib_def,thread::ThreadPool * thread_pool,const OptimizerOptions & optimizer_options,const SessionMetadata * session_metadata,ProcessFunctionLibraryRuntime * parent)1372 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
1373     const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
1374     Device* device, int graph_def_version,
1375     const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool,
1376     const OptimizerOptions& optimizer_options,
1377     const SessionMetadata* session_metadata,
1378     ProcessFunctionLibraryRuntime* parent) {
1379   return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
1380       device_mgr, env, config, device, graph_def_version, lib_def, thread_pool,
1381       optimizer_options, session_metadata, parent));
1382 }
1383 
// Helper that derives the body of the symbolic gradient function from an
// existing FunctionBody (see Compute()).
class SymbolicGradientHelper {
 public:
  // Does not take ownership of `f`; `f` must outlive this helper.
  explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {}
  ~SymbolicGradientHelper() = default;

  // Returns a newly constructed FunctionBody implementing the gradient of
  // `fbody_`.
  std::unique_ptr<FunctionBody> Compute();

 private:
  const FunctionBody* fbody_;

  // Makes a copy of fbody_ in gbody.
  void Copy(FunctionBody* gbody);

  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientHelper);
};
1399 
Copy(FunctionBody * gbody)1400 void SymbolicGradientHelper::Copy(FunctionBody* gbody) {
1401   const Graph& src = *(fbody_->graph);
1402   gbody->graph = new Graph(src.op_registry());
1403   Graph* dst = gbody->graph;
1404 
1405   std::vector<Node*> node_map(src.num_node_ids());
1406 
1407   // Copy just the fdef attributes (copy '_noinline' and other similar flags to
1408   // the gradient function body).
1409   *(gbody->fdef.mutable_attr()) = fbody_->fdef.attr();
1410 
1411   // Copy the nodes.
1412   node_map[src.source_node()->id()] = dst->source_node();
1413   node_map[src.sink_node()->id()] = dst->sink_node();
1414   for (Node* n : src.op_nodes()) {
1415     node_map[n->id()] = dst->CopyNode(n);
1416   }
1417 
1418   // Copy the edges.
1419   for (const Edge* e : src.edges()) {
1420     Node* src_copy = node_map[e->src()->id()];
1421     Node* dst_copy = node_map[e->dst()->id()];
1422     dst->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
1423   }
1424 
1425   // Save inputs in copied graph.
1426   CHECK_EQ(fbody_->arg_types.size(), fbody_->arg_nodes.size());
1427   gbody->arg_types = fbody_->arg_types;
1428   for (std::size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
1429     gbody->arg_nodes.push_back(node_map[fbody_->arg_nodes[i]->id()]);
1430   }
1431 
1432   // Save outputs in copied graph.
1433   CHECK_EQ(fbody_->ret_types.size(), fbody_->ret_nodes.size());
1434   gbody->ret_types = fbody_->ret_types;
1435   for (std::size_t i = 0; i < fbody_->ret_nodes.size(); ++i) {
1436     gbody->ret_nodes.push_back(node_map[fbody_->ret_nodes[i]->id()]);
1437   }
1438 }
1439 
Compute()1440 std::unique_ptr<FunctionBody> SymbolicGradientHelper::Compute() {
1441   FunctionBody* gbody = new FunctionBody;
1442   Copy(gbody);  // copy fbody_ into gbody.
1443 
1444   Graph* g = gbody->graph;
1445 
1446   const int num_y = static_cast<int>(gbody->ret_nodes.size());
1447 
1448   // Populate 'y_node_outputs_' with node function body outputs.
1449   // Populate 'y_grad_nodes' with initial gradient nodes for each return node
1450   // of the original function body (these will be 'arg' nodes in the function
1451   // gradient body).
1452   std::vector<NodeOut> y_node_outputs;
1453   y_node_outputs.reserve(num_y);
1454   std::vector<NodeOut> y_grad_node_outputs;
1455   y_grad_node_outputs.reserve(num_y);
1456   for (int i = 0; i < num_y; ++i) {
1457     Node* y = gbody->ret_nodes[i];
1458     y_node_outputs.push_back({y, 0});
1459     DCHECK_EQ(y->type_string(), kRetOp);
1460     const DataType dtype = y->input_type(0);
1461     const int index = static_cast<int>(gbody->arg_nodes.size());
1462     Node* dy = AddArg(g, dtype, index);
1463     gbody->arg_types.push_back(dtype);
1464     gbody->arg_nodes.push_back(dy);
1465     y_grad_node_outputs.push_back({dy, 0});
1466   }
1467 
1468   // Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
1469   const size_t num_x = fbody_->arg_nodes.size();
1470   std::vector<NodeOut> x_node_outputs;
1471   x_node_outputs.reserve(num_x);
1472   for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
1473     x_node_outputs.push_back({gbody->arg_nodes[i], 0});
1474   }
1475 
1476   // Call AddSymbolicGradients which will add nodes to graph 'g' that
1477   // compute the function gradient (adding an entry in 'x_grad_node_outputs'
1478   // for each node in 'x_node_outputs').
1479   std::vector<NodeOut> x_grad_node_outputs;
1480   TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs,
1481                                    y_grad_node_outputs, &x_grad_node_outputs,
1482                                    g));
1483 
1484   // Remove the old return nodes from the function body.
1485   for (Node* n : gbody->ret_nodes) {
1486     g->RemoveNode(n);
1487   }
1488   gbody->ret_types = fbody_->arg_types;
1489   // TODO(apassos): use the right dtype for gradients of  resource variables
1490   for (int i = 0; i < gbody->ret_types.size(); ++i) {
1491     if (gbody->ret_types[i] == DT_RESOURCE) {
1492       gbody->ret_types[i] = DT_FLOAT;
1493     }
1494   }
1495   gbody->ret_nodes.clear();
1496   // Add new return nodes to the function gradient body for each node
1497   // in 'x_grad_nodes'.
1498   const int arg_types_size = static_cast<int>(fbody_->arg_types.size());
1499   for (int i = 0; i < arg_types_size; ++i) {
1500     Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
1501     Node* ret = AddRet(g, grad, i);
1502     gbody->ret_nodes.push_back(ret);
1503   }
1504 
1505   return std::unique_ptr<FunctionBody>(gbody);
1506 }
1507 
SymbolicGradient(const FunctionBody & f)1508 std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f) {
1509   return SymbolicGradientHelper(f).Compute();
1510 }
1511 
1512 }  // end namespace tensorflow
1513