1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/common_runtime/function.h"

#include <deque>
#include <memory>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/gradients.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/optimizer_cse.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/config.pb.h"
51
52 // See core/kernels/function_ops.cc for related kernels.
53
54 namespace tensorflow {
55
56 // A few string constant used throughout this module.
// A few string constants used throughout this module.  All but kNodeLabel
// are aliases for the canonical op/attr names declared on
// FunctionLibraryDefinition.
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
static constexpr const char* const kDeviceArgOp =
    FunctionLibraryDefinition::kDeviceArgOp;
static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
static constexpr const char* const kDeviceRetOp =
    FunctionLibraryDefinition::kDeviceRetOp;
static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;
// Base name used when generating fresh node names (see Graph::NewName).
static constexpr const char* const kNodeLabel = "Func";
static constexpr const char* const kFuncAttr =
    FunctionLibraryDefinition::kFuncAttr;
68
69 // Represents the index-th output of a node.
// Represents the index-th output of a node.
struct Endpoint {
  Node* node;  // Not owned.
  int index;   // Output slot on `node`.

  // Returns the string name representing this endpoint: the bare node name
  // for output 0, otherwise "<node>:<index>" (the conventional tensor-name
  // form).
  string name() const {
    if (index == 0) {
      return node->name();
    } else {
      return strings::StrCat(node->name(), ":", index);
    }
  }

  // The data type produced at this output slot.
  DataType dtype() const { return node->output_type(index); }
};
85
// Hash functor so Endpoint can be used as a key in hash-based containers.
struct EndpointHash {
  uint64 operator()(const Endpoint& x) const {
    // Hash the raw bytes of the node pointer, seeded with the output index.
    return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
                  x.index);
  }
};
92
// Equality functor paired with EndpointHash: two endpoints are equal iff
// they reference the same output slot of the same node.
struct EndpointEq {
  bool operator()(const Endpoint& x, const Endpoint& y) const {
    return (x.node == y.node) && (x.index == y.index);
  }
};
98
99 // The following Add* routines are used to add a few graph nodes while
100 // functions are transformed.
AddArg(Graph * g,DataType dtype,int index)101 static Node* AddArg(Graph* g, DataType dtype, int index) {
102 DCHECK_LT(0, dtype);
103 DCHECK_LT(dtype, DT_FLOAT_REF);
104 NodeDef ndef;
105 ndef.set_name(g->NewName(kNodeLabel));
106 ndef.set_op(kArgOp);
107 AddNodeAttr("T", dtype, &ndef);
108 AddNodeAttr("index", index, &ndef);
109 Status s;
110 Node* ret = g->AddNode(ndef, &s);
111 TF_CHECK_OK(s);
112 return ret;
113 }
114
AddRet(Graph * g,Endpoint input,int index)115 static Node* AddRet(Graph* g, Endpoint input, int index) {
116 DCHECK_LT(0, input.dtype());
117 DCHECK_LT(input.dtype(), DT_FLOAT_REF);
118 NodeDef ndef;
119 ndef.set_name(g->NewName(kNodeLabel));
120 ndef.set_op(kRetOp);
121 ndef.add_input(input.name());
122 AddNodeAttr("T", input.dtype(), &ndef);
123 AddNodeAttr("index", index, &ndef);
124 Status s;
125 Node* ret = g->AddNode(ndef, &s);
126 TF_CHECK_OK(s);
127 g->AddEdge(input.node, input.index, ret, 0);
128 return ret;
129 }
130
131 // FunctionLibraryRuntime implementation that forwards all the function calls to
132 // the base runtime implementation, and only overrides FunctionLibraryDefinition
133 // in calls to Instantiate (if caller doesn't provide the
134 // InstantiateOptions::lib_def option).
135 //
136 // When the function library runtime (FunctionLibraryRuntimeImpl specifically)
137 // instantiates a function into a Graph object, it also creates an Executor for
138 // it. That executor has a pointer to the function library runtime instance,
139 // that is used to instantiate all nested function calls.
140 //
141 // The function library definition used to instantiate the function must be
142 // preserved in the Executor's function library runtime.
143 //
144 // IMPORTANT: This runtime is intended for use only in executors created for
145 // functions instantiated into a graph in FunctionLibraryRuntimeImpl.
class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
 public:
  // Neither `base_flr` nor `lib_def` is owned; both must outlive this
  // overlay.
  FunctionLibraryRuntimeOverlay(FunctionLibraryRuntime* base_flr,
                                const FunctionLibraryDefinition* lib_def)
      : base_flr_(base_flr), lib_def_(lib_def) {}
  ~FunctionLibraryRuntimeOverlay() override;

  // Forwards to the base runtime, injecting `lib_def_` into the options when
  // the caller did not supply one (see implementation below).
  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle h) override;

  Status GetRetTypes(Handle h, DataTypeVector* ret_types) override;

  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;

  void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame,
           DoneCallback done) override;

  Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
                 std::vector<Tensor>* rets) override;

  Status RunSync(Options opts, Handle handle,
                 CallFrameInterface* frame) override;

  // Always fails: kernel creation through an overlay is unsupported (see the
  // implementation comment for why).
  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      OpKernel** kernel) override;

  // Resolved against the overlayed `lib_def_` only; not forwarded.
  bool IsStateful(const string& function_name) const override;

  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override;

  // The accessors below forward to the base runtime unchanged.
  Env* env() override;
  const ConfigProto* const config_proto() override;
  Device* device() override;
  const Device* device() const override;
  std::function<void(std::function<void()>)>* runner() override;
  const DeviceMgr* device_mgr() const override;

  string DebugString(Handle handle) override;
  int graph_def_version() const override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr,
               bool skip_flib_def = false) override;

 private:
  FunctionLibraryRuntime* base_flr_;          // not owned
  const FunctionLibraryDefinition* lib_def_;  // not owned
};
202
// Defaulted: the overlay owns neither base_flr_ nor lib_def_.
FunctionLibraryRuntimeOverlay::~FunctionLibraryRuntimeOverlay() = default;
204
Instantiate(const string & function_name,AttrSlice attrs,const InstantiateOptions & options,Handle * handle)205 Status FunctionLibraryRuntimeOverlay::Instantiate(
206 const string& function_name, AttrSlice attrs,
207 const InstantiateOptions& options, Handle* handle) {
208 // We automatically set the `lib_def` option for all instantiations, if the
209 // caller doesn't set this option explicitly.
210 if (!options.lib_def && lib_def_) {
211 InstantiateOptions options_copy = options;
212 options_copy.lib_def = lib_def_;
213 return base_flr_->Instantiate(function_name, attrs, options_copy, handle);
214 } else {
215 return base_flr_->Instantiate(function_name, attrs, options, handle);
216 }
217 }
218
// The methods below are straight pass-throughs to the base runtime; the
// overlay adds no behavior of its own here.

Status FunctionLibraryRuntimeOverlay::ReleaseHandle(Handle handle) {
  return base_flr_->ReleaseHandle(handle);
}

const FunctionBody* FunctionLibraryRuntimeOverlay::GetFunctionBody(Handle h) {
  return base_flr_->GetFunctionBody(h);
}

Status FunctionLibraryRuntimeOverlay::GetRetTypes(Handle h,
                                                  DataTypeVector* ret_types) {
  return base_flr_->GetRetTypes(h, ret_types);
}

void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
                                        gtl::ArraySlice<Tensor> args,
                                        std::vector<Tensor>* rets,
                                        DoneCallback done) {
  base_flr_->Run(opts, handle, args, rets, std::move(done));
}

void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
                                        CallFrameInterface* call_frame,
                                        DoneCallback done) {
  base_flr_->Run(opts, handle, call_frame, std::move(done));
}

Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
                                              gtl::ArraySlice<Tensor> args,
                                              std::vector<Tensor>* rets) {
  return base_flr_->RunSync(std::move(opts), handle, args, rets);
}

Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
                                              CallFrameInterface* call_frame) {
  return base_flr_->RunSync(std::move(opts), handle, call_frame);
}
255
// Deliberately unsupported; always returns an Internal error.
Status FunctionLibraryRuntimeOverlay::CreateKernel(
    const std::shared_ptr<const NodeProperties>&, OpKernel**) {
  // We don't have access to base_lib_def_ in base function library runtime (aka
  // FunctionLibraryRuntimeImpl), so to make sure we do not create a kernel with
  // the wrong lib_def we just disable creation of new kernels through overlays.
  //
  // When we call Instantiate from the base runtime with the lib_def option,
  // the base runtime implementation is responsible for correctly passing it
  // through to all kernel constructions.
  return errors::Internal(
      "Overlay function library runtime doesn't support kernel creation.");
}
268
IsStateful(const string & function_name) const269 bool FunctionLibraryRuntimeOverlay::IsStateful(
270 const string& function_name) const {
271 // Important: we do not forward lookup to the base FLR.
272 const OpDef* op_def;
273 const Status s = lib_def_->LookUpOpDef(function_name, &op_def);
274 return s.ok() && op_def->is_stateful();
275 }
276
// Simple accessors: all forward to the base runtime, except
// GetFunctionLibraryDefinition which prefers the overlayed library.

Env* FunctionLibraryRuntimeOverlay::env() { return base_flr_->env(); }

const ConfigProto* const FunctionLibraryRuntimeOverlay::config_proto() {
  return base_flr_->config_proto();
}

Device* FunctionLibraryRuntimeOverlay::device() { return base_flr_->device(); }

const Device* FunctionLibraryRuntimeOverlay::device() const {
  return base_flr_->device();
}

std::function<void(std::function<void()>)>*
FunctionLibraryRuntimeOverlay::runner() {
  return base_flr_->runner();
}

const DeviceMgr* FunctionLibraryRuntimeOverlay::device_mgr() const {
  return base_flr_->device_mgr();
}

const FunctionLibraryDefinition*
FunctionLibraryRuntimeOverlay::GetFunctionLibraryDefinition() const {
  // Prefer the overlayed library; fall back to the base runtime's if none.
  return lib_def_ ? lib_def_ : base_flr_->GetFunctionLibraryDefinition();
}

string FunctionLibraryRuntimeOverlay::DebugString(Handle handle) {
  return base_flr_->DebugString(handle);
}

int FunctionLibraryRuntimeOverlay::graph_def_version() const {
  return base_flr_->graph_def_version();
}
310
// Clones the *base* runtime; the overlayed lib_def_ is intentionally dropped.
Status FunctionLibraryRuntimeOverlay::Clone(
    std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
    FunctionLibraryRuntime** out_flr, bool skip_flib_def) {
  // NOTE(ezhulenev): The cloned FunctionLibraryRuntime will be missing the
  // FunctionLibraryDefinition override, but that's ok because we anyway do not
  // copy / clone instantiated items from the base FLR.
  return base_flr_->Clone(out_lib_def, out_pflr, out_flr, skip_flib_def);
}
320
// Per-device function library runtime: instantiates functions into
// FunctionBody/Graph/Executor state (tracked as `Item`s) and runs them.
// Requests that do not target this runtime's device are delegated to
// `parent_` (the ProcessFunctionLibraryRuntime).
class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
 public:
  // All pointer arguments are retained but not owned; they must outlive this
  // object.
  FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env,
                             const ConfigProto* config, Device* device,
                             int graph_def_version,
                             const FunctionLibraryDefinition* lib_def,
                             thread::ThreadPool* default_thread_pool,
                             const OptimizerOptions& optimizer_options,
                             const SessionMetadata* session_metadata,
                             ProcessFunctionLibraryRuntime* parent);

  ~FunctionLibraryRuntimeImpl() override;

  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle handle) override;

  Status GetRetTypes(Handle handle, DataTypeVector* ret_types) override;

  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      OpKernel** kernel) override;

  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;
  void Run(const Options& opts, Handle handle, CallFrameInterface* frame,
           DoneCallback done) override;
  Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
                 std::vector<Tensor>* rets) override;
  Status RunSync(Options opts, Handle handle,
                 CallFrameInterface* call_frame) override;

  bool IsStateful(const string& function) const override;

  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override {
    return base_lib_def_;
  }

  Device* device() override { return device_; }
  const Device* device() const override { return device_; }

  std::function<void(std::function<void()>)>* runner() override {
    return &default_runner_;
  }

  const DeviceMgr* device_mgr() const override { return device_mgr_; }
  Env* env() override { return env_; }
  const ConfigProto* const config_proto() override { return config_; }
  int graph_def_version() const override { return graph_def_version_; }

  string DebugString(Handle h) override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr,
               bool skip_flib_def = false) override;

 private:
  typedef FunctionLibraryRuntimeImpl ME;

  // None of the raw pointers below are owned (see constructor).
  const DeviceMgr* const device_mgr_;
  Device* const device_;
  Env* const env_;
  const ConfigProto* const config_;
  const int graph_def_version_;
  const FunctionLibraryDefinition* const base_lib_def_;
  GraphOptimizer optimizer_;
  const SessionMetadata* const session_metadata_;
  Executor::Args::Runner default_runner_;
  // Canonical device name used as the key for handle bookkeeping in parent_.
  const string device_name_;

  // Cached op-signature lookup bound to base_lib_def_ (set in constructor).
  std::function<Status(const string&, const OpDef**)> get_func_sig_;
  // Kernel factory bound to this runtime (set in constructor).
  std::function<Status(const std::shared_ptr<const NodeProperties>&,
                       OpKernel**)>
      create_kernel_;

  mutable mutex mu_;

  // Next LocalHandle value to assign.
  int next_handle_ TF_GUARDED_BY(mu_);

  // The instantiated and transformed function is encoded as a Graph
  // object, and an executor is created for the graph.
  struct Item {
    uint64 instantiation_counter = 0;
    std::unique_ptr<const Graph> graph = nullptr;
    const FunctionLibraryDefinition* lib_def = nullptr;  // Not owned.
    FunctionBody* func_graph = nullptr;
    Executor* exec = nullptr;
    FunctionLibraryRuntimeOverlay* overlay_flr = nullptr;
    string executor_type;

    // Item owns func_graph, exec and overlay_flr.
    ~Item() {
      delete this->func_graph;
      delete this->exec;
      delete this->overlay_flr;
    }
  };
  // Held by unique_ptr so the destructor can release it explicitly and
  // ReleaseHandle can detect teardown (see ~FunctionLibraryRuntimeImpl).
  std::unique_ptr<std::unordered_map<Handle, std::unique_ptr<Item>>> items_
      TF_GUARDED_BY(mu_);

  ProcessFunctionLibraryRuntime* parent_ = nullptr;  // not owned.

  // Overloads the CreateKernel method, providing a FunctionLibraryRuntime
  // to use for kernel creation and execution. In particular, this method can
  // accept a FunctionLibraryRuntimeOverlay that overlays a different
  // FunctionLibraryDefinition.
  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      FunctionLibraryRuntime* flr, OpKernel** kernel);
  // Converts `fdef` into a FunctionBody, resolving ops against `lib_def`.
  Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
                           const FunctionLibraryDefinition* lib_def,
                           std::unique_ptr<FunctionBody>* fbody);
  Status CreateItem(Item** item);
  Status GetOrCreateItem(LocalHandle local_handle, Item** item);
  // Builds the gradient body for `func`, via a registered gradient creator
  // for primitive ops or symbolic differentiation for user-defined functions.
  Status InstantiateSymbolicGradient(const NameAttrList& func,
                                     const FunctionLibraryDefinition* lib_def,
                                     std::unique_ptr<FunctionBody>* g_body);
  // True iff `options` targets this runtime's device (or no device at all).
  bool IsLocalTarget(const InstantiateOptions& options) const;
  AttrValueMap FixAttrs(const AttrSlice& attrs);
  void RunRemote(const Options& opts, Handle handle,
                 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                 Item* item, DoneCallback done);

  Status PrepareRunSync(
      Handle handle, Options* run_opts, Item** out_item,
      std::unique_ptr<PrivateIntraProcessRendezvous>* out_rendezvous);

  void ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options& run_opts,
                               CallFrameInterface* frame,
                               Executor::Args* exec_args);

  TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
};
457
// Constructor: caches the supplied configuration, derives the canonical
// device name used for handle bookkeeping, and binds the signature-lookup
// and kernel-creation callbacks used during instantiation.
FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
    const DeviceMgr* dmgr, Env* env, const ConfigProto* config, Device* device,
    int graph_def_version, const FunctionLibraryDefinition* lib_def,
    thread::ThreadPool* default_thread_pool,
    const OptimizerOptions& optimizer_options,
    const SessionMetadata* session_metadata,
    ProcessFunctionLibraryRuntime* parent)
    : device_mgr_(dmgr),
      device_(device),
      env_(env),
      config_(config),
      graph_def_version_(graph_def_version),
      base_lib_def_(lib_def),
      optimizer_(optimizer_options),
      session_metadata_(session_metadata),
      default_runner_(nullptr),
      // With no device, use the PFLR's sentinel device name.
      device_name_(device_ == nullptr
                       ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
                       : device_->name()),
      next_handle_(0),
      items_(new std::unordered_map<Handle, std::unique_ptr<Item>>),
      parent_(parent) {
  get_func_sig_ = [this](const string& op, const OpDef** sig) {
    return base_lib_def_->LookUpOpDef(op, sig);
  };
  create_kernel_ = [this](const std::shared_ptr<const NodeProperties>& props,
                          OpKernel** kernel) {
    return CreateKernel(props, kernel);
  };
  // Prefer the device's own thread pool; fall back to the provided default
  // pool. If neither is available, default_runner_ stays null.
  thread::ThreadPool* pool = nullptr;
  if (device_ != nullptr) {
    pool = device_->tensorflow_device_thread_pool();
  }
  if (pool == nullptr) {
    pool = default_thread_pool;
  }
  if (pool != nullptr) {
    default_runner_ = [pool](Executor::Args::Closure c) {
      pool->Schedule(std::move(c));
    };
  }
}
500
// Destructor: tears down all instantiated items up front (see below for why
// this must happen explicitly rather than via member destruction order).
FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
  // Deleting the items_ list will delete all the function handles registered in
  // this object. A function may contains a few sub-functions which have also
  // been registered in this object. Deleting the parent function will call
  // ReleaseHandle in this class again for each of the sub-functions. These
  // circular calls may cause segfault since the items_ may have already been
  // partially deleted when releasing handles of sub-functions. Explicitly
  // release items_ here and check it in ReleaseHandle to avoid this.
  items_.reset();
}
511
512 // An asynchronous op kernel which executes an instantiated function
513 // defined in a library.
514 class CallOp : public AsyncOpKernel {
515 public:
CallOp(FunctionLibraryRuntime::Handle handle,OpKernelConstruction * ctx)516 CallOp(FunctionLibraryRuntime::Handle handle, OpKernelConstruction* ctx)
517 : AsyncOpKernel(ctx), handle_(handle) {}
518
~CallOp()519 ~CallOp() override {
520 // TODO(iga): Release the cached handle_
521 }
522
ComputeAsync(OpKernelContext * ctx,DoneCallback done)523 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
524 FunctionLibraryRuntime* lib = ctx->function_library();
525 OP_REQUIRES_ASYNC(ctx, lib != nullptr,
526 errors::Internal("No function library is provided."),
527 done);
528 FunctionLibraryRuntime::Options opts;
529 opts.rendezvous = ctx->rendezvous();
530 opts.cancellation_manager = ctx->cancellation_manager();
531 opts.step_container = ctx->step_container();
532 opts.stats_collector = ctx->stats_collector();
533 opts.runner = ctx->runner();
534 opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
535 opts.collective_executor = ctx->collective_executor();
536 std::vector<Tensor> args;
537 args.reserve(ctx->num_inputs());
538 for (int i = 0; i < ctx->num_inputs(); ++i) {
539 args.push_back(ctx->input(i));
540 }
541 std::vector<Tensor>* rets = new std::vector<Tensor>;
542 profiler::TraceMe trace_me(
543 [&] {
544 return absl::StrCat("CallOp #parent_step_id=", ctx->step_id(),
545 ",function_step_id=", opts.step_id, "#");
546 },
547 /*level=*/2);
548 lib->Run(opts, handle_, args, rets,
549 [ctx, done, rets](const Status& status) {
550 if (!status.ok()) {
551 ctx->SetStatus(status);
552 } else {
553 const int ret_size = static_cast<int>(rets->size());
554 CHECK_EQ(ret_size, ctx->num_outputs());
555 for (int i = 0; i < ret_size; ++i) {
556 ctx->set_output(i, (*rets)[i]);
557 }
558 }
559 delete rets;
560 done();
561 });
562 }
563
564 private:
565 FunctionLibraryRuntime::Handle handle_;
566
567 TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
568 };
569
// Returns the FunctionBody for handle `h`, or nullptr (with an error log) if
// the handle is not registered on this runtime's device.
const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
  if (local_handle == kInvalidLocalHandle) {
    LOG(ERROR) << "Could not find Handle: " << h
               << " on device: " << device_name_;
    return nullptr;
  }

  tf_shared_lock l(mu_);
  auto iter = items_->find(local_handle);
  CHECK(iter != items_->end());  // A valid local handle must have an Item.
  return iter->second->func_graph;
}
583
// Fills `ret_types` with the return types of the function behind `h`.
// Multi-device handles are answered by the parent PFLR.
Status FunctionLibraryRuntimeImpl::GetRetTypes(Handle h,
                                               DataTypeVector* ret_types) {
  if (parent_->IsMultiDevice(h)) {
    return parent_->GetRetTypes(h, ret_types);
  }
  // Validate the handle first so GetFunctionBody below cannot return null.
  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
  if (local_handle == kInvalidLocalHandle) {
    return errors::InvalidArgument("Handle ", h, " not found.");
  }
  const FunctionBody* fbody = GetFunctionBody(h);
  *ret_types = fbody->ret_types;
  return Status::OK();
}
597
// Public entry point: creates the kernel against this runtime itself (no
// overlay). See the three-argument overload for the actual logic.
Status FunctionLibraryRuntimeImpl::CreateKernel(
    const std::shared_ptr<const NodeProperties>& props, OpKernel** kernel) {
  return CreateKernel(props, this, kernel);
}
602
// Creates an OpKernel for `props`, resolving function calls against `flr`'s
// library. Primitive ops get their registered kernel; function calls are
// instantiated and wrapped in a CallOp. On success, `*kernel` is owned by
// the caller.
Status FunctionLibraryRuntimeImpl::CreateKernel(
    const std::shared_ptr<const NodeProperties>& props,
    FunctionLibraryRuntime* flr, OpKernel** kernel) {
  // If a custom kernel creator is given, try that.
  Status s;
  const CustomKernelCreator* custom_kernel_creator =
      GetDefaultCustomKernelCreator();
  if (custom_kernel_creator &&
      custom_kernel_creator->CanCreateKernel(*this, props)) {
    std::unique_ptr<OpKernel> ret;
    s = custom_kernel_creator->CreateKernel(this, props, &ret);
    if (s.ok()) {
      *kernel = ret.release();
    } else {
      VLOG(2) << "Custom creator error: " << s;
    }
    return s;
  }

  const FunctionLibraryDefinition* lib_def =
      flr->GetFunctionLibraryDefinition();
  if (lib_def->Find(props->node_def.op()) == nullptr) {
    // A primitive operation. Creates the registered kernel.
    return CreateNonCachedKernel(device_, flr, props, graph_def_version_,
                                 kernel);
  }

  // Try to instantiate this function for the func/attr. Maybe it's
  // cached already.
  InstantiateOptions options;
  if (lib_def != base_lib_def_) {
    options.lib_def = lib_def;
  }
  Handle handle;
  TF_RETURN_IF_ERROR(Instantiate(props->node_def.op(),
                                 AttrSlice(&props->node_def.attr()), options,
                                 &handle));

  const FunctionBody* fbody = GetFunctionBody(handle);
  CHECK_NOTNULL(fbody);

  // TODO(zhifengc): For now, we assume int32 and resources are always on host
  // memory and other types are always on device memory. We should do type
  // inference over function body to derive the correct input/output memory
  // types.
  MemoryTypeVector input_memory_types;
  for (const auto& t : fbody->arg_types) {
    input_memory_types.push_back(MTypeFromDType(t));
  }
  MemoryTypeVector output_memory_types;
  for (const auto& t : fbody->ret_types) {
    output_memory_types.push_back(MTypeFromDType(t));
  }

  // Constructs a CallOp kernel for running the instantiated function.
  auto device_type = DeviceType(device_->attributes().device_type());
  // NOTE: new_props is built but the construction below passes the original
  // `props`; presumably intentional — TODO(review): confirm.
  auto new_props = std::make_shared<NodeProperties>(
      &fbody->fdef.signature(), props->node_def, fbody->arg_types,
      fbody->ret_types);
  OpKernelConstruction construction(
      device_type, device_, device_->GetAllocator(AllocatorAttributes()), flr,
      device_->resource_manager(), props, input_memory_types,
      output_memory_types, graph_def_version_, &s);
  if (s.ok()) {
    *kernel = new CallOp(handle, &construction);
  }
  return s;
}
671
FunctionDefToBody(const FunctionDef & fdef,AttrSlice attrs,const FunctionLibraryDefinition * lib_def,std::unique_ptr<FunctionBody> * fbody)672 Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
673 const FunctionDef& fdef, AttrSlice attrs,
674 const FunctionLibraryDefinition* lib_def,
675 std::unique_ptr<FunctionBody>* fbody) {
676 if (lib_def == base_lib_def_) {
677 return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig_, fbody);
678 } else {
679 auto get_func_sig = [lib_def](const string& op, const OpDef** sig) {
680 return lib_def->LookUpOpDef(op, sig);
681 };
682 return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
683 }
684 }
685
// Builds the body of the gradient of `func`. Primitive ops use their
// registered gradient creator; user-defined functions are instantiated first
// and then differentiated symbolically.
Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
    const NameAttrList& func, const FunctionLibraryDefinition* lib_def,
    std::unique_ptr<FunctionBody>* g_body) {
  const FunctionDef* fdef = lib_def->Find(func.name());
  if (fdef == nullptr) {
    // f is a primitive op.
    gradient::Creator creator;
    TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
    if (creator == nullptr) {
      return errors::InvalidArgument("No gradient is defined for ",
                                     func.name());
    }
    FunctionDef grad_fdef;
    // TODO(josh11b): Should filter out the attrs from func that aren't used
    // by the gradient function.
    TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
    TF_RETURN_IF_ERROR(
        FunctionDefToBody(grad_fdef, AttrSlice(&func.attr()), lib_def, g_body));
  } else {
    // f is a user-defined function.
    InstantiateOptions options;
    if (lib_def != base_lib_def_) {
      options.lib_def = lib_def;
    }
    Handle f_handle;
    TF_RETURN_IF_ERROR(
        Instantiate(func.name(), AttrSlice(&func.attr()), options, &f_handle));
    const FunctionBody* f_body = GetFunctionBody(f_handle);
    CHECK_NOTNULL(f_body);
    *g_body = SymbolicGradient(*f_body);
  }
  return Status::OK();
}
719
IsLocalTarget(const InstantiateOptions & options) const720 bool FunctionLibraryRuntimeImpl::IsLocalTarget(
721 const InstantiateOptions& options) const {
722 if (device_ == nullptr) return true;
723 if (options.target.empty()) return true;
724 if (options.is_multi_device_function) return false;
725 Device* target_device;
726 if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) {
727 VLOG(1) << "Not instantiating function in FLR because failed to "
728 << "find device " << options.target << " in device manager";
729 return false;
730 }
731 if (target_device != device_) {
732 VLOG(1) << "Not instantiating function in FLR because target device "
733 << options.target
734 << " is different from FLR's device: " << device_->DebugString();
735 return false;
736 }
737 return true;
738 }
739
// Instantiates `function_name` with `attrs`, reusing a cached instantiation
// when one exists for the canonicalized key. Non-local targets are delegated
// to the parent PFLR. On success `*handle` identifies the instantiation and
// its reference count has been incremented.
Status FunctionLibraryRuntimeImpl::Instantiate(
    const string& function_name, AttrSlice attrs,
    const InstantiateOptions& options, Handle* handle) {
  if (!IsLocalTarget(options)) {
    return parent_->Instantiate(function_name, attrs, options, handle);
  }

  // Since this is a local target, ensure that the local `device_name_` appears
  // in the canonical key.
  InstantiateOptions options_copy(options);
  options_copy.target = device_name_;
  const string key = Canonicalize(function_name, attrs, options_copy);

  // Fast path: an existing instantiation just gets its refcount bumped.
  {
    mutex_lock l(mu_);
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      FunctionLibraryRuntime::LocalHandle handle_on_device =
          parent_->GetHandleOnDevice(device_name_, *handle);
      if (handle_on_device == kInvalidLocalHandle) {
        return errors::Internal("LocalHandle not found for handle ", *handle,
                                ".");
      }
      auto item_handle = items_->find(handle_on_device);
      if (item_handle == items_->end()) {
        return errors::Internal("LocalHandle ", handle_on_device,
                                " for handle ", *handle,
                                " not found in items.");
      }
      ++item_handle->second->instantiation_counter;
      return Status::OK();
    }
  }

  // Slow path: build the function body outside the lock (this can recurse
  // back into Instantiate, e.g. for gradients).
  const FunctionLibraryDefinition* lib_def =
      options.lib_def ? options.lib_def : base_lib_def_;
  std::unique_ptr<FunctionBody> fbody;
  if (function_name == kGradientOp) {
    // SymbolicGradient: the function to differentiate is carried in attr "f".
    const AttrValue* f = attrs.Find(kFuncAttr);
    if (f == nullptr) {
      return errors::InvalidArgument("SymbolicGradient is missing attr: f");
    }
    const auto& func = f->func();
    if (func.name() == kGradientOp) {
      return errors::InvalidArgument("Can't take gradient of SymbolicGradient");
    }
    // A registered custom gradient takes precedence over symbolic
    // differentiation.
    const string grad = lib_def->FindGradient(func.name());
    if (!grad.empty()) {
      return Instantiate(grad, AttrSlice(&func.attr()), options, handle);
    }
    TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, lib_def, &fbody));
  } else {
    const FunctionDef* fdef = lib_def->Find(function_name);
    if (fdef == nullptr) {
      return errors::NotFound("Function ", function_name, " is not defined.");
    }
    TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody));
  }

  // Register the new instantiation; re-check the cache because the lock was
  // dropped and a concurrent caller may have instantiated the same key.
  LocalHandle local_handle;
  {
    mutex_lock l(mu_);
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      local_handle = parent_->GetHandleOnDevice(device_name_, *handle);
      ++(*items_)[local_handle]->instantiation_counter;
    } else {
      *handle = parent_->AddHandle(key, device_name_, next_handle_);
      Item* item = new Item;
      item->func_graph = fbody.release();
      item->instantiation_counter = 1;
      item->executor_type = ExecutorType(options, attrs);
      if (options.lib_def) {
        // Remember the overriding library so nested calls resolve against it.
        item->overlay_flr =
            new FunctionLibraryRuntimeOverlay(this, options.lib_def);
      }
      local_handle = next_handle_++;
      items_->emplace(local_handle, std::unique_ptr<Item>(item));
    }
  }

  // Optionally build the executor (and its kernels) right away rather than
  // lazily on first Run.
  if (options.create_kernels_eagerly) {
    Item* item;
    TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, &item));
  }

  return Status::OK();
}
828
ReleaseHandle(Handle handle)829 Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
830 LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
831 if (h == kInvalidLocalHandle) {
832 return parent_->ReleaseHandle(handle);
833 }
834 std::unique_ptr<Item> item_to_delete;
835 Status parent_status;
836 {
837 mutex_lock l(mu_);
838 // Return directly if all items has already been released.
839 if (items_ == nullptr) return Status::OK();
840
841 auto it = items_->find(h);
842 if (it == items_->end()) {
843 return errors::Internal(
844 "Inconsistent FunctionLibraryRuntime. Expected to find an item for "
845 "handle ",
846 h, " but found none");
847 }
848 std::unique_ptr<Item>& item = it->second;
849 --item->instantiation_counter;
850 if (item->instantiation_counter == 0) {
851 // We don't simply erase h's item because that would trigger
852 // item destruction while holding mu_. Item destruction can
853 // trigger graph destruction. If the graph contains kernels like
854 // CallOp or PartitionCallOp, their destructors will release cached
855 // function handles, resulting in deadlock here.
856 item_to_delete = std::move(item);
857 items_->erase(h);
858 parent_status = parent_->RemoveHandle(handle);
859 }
860 }
861 return parent_status;
862 }
863
864 namespace {
865 // Removes all stateless nodes that do not contribute to a return
866 // value from the function body. Unlike `RemoveDeadNodes()`, which is
867 // triggered by `OptimizerOptions.do_function_inlining`, this pass
868 // ignores the SINK node, from which (by definition) all nodes are
869 // reverse reachable, and preserves all nodes that are reachable from
870 // control output nodes.
871 //
872 // TODO(ezhulenev, skyewm): Function body should not have special treatment of
873 // stateful ops, graph should encode nodes that must execute with `control_ret`
874 // and `control_output`.
PruneFunctionBody(const FunctionDef & fdef,Graph * g)875 void PruneFunctionBody(const FunctionDef& fdef, Graph* g) {
876 VLOG(2) << "Pruning function body: function_name=" << fdef.signature().name();
877
878 // `control_ret` nodes must be always executed.
879 std::unordered_set<StringPiece, StringPieceHasher> control_ret_nodes;
880 for (const auto& control_ret : fdef.control_ret()) {
881 control_ret_nodes.insert(control_ret.second);
882 }
883
884 std::unordered_set<const Node*> nodes;
885 for (auto n : g->nodes()) {
886 // NOTE(mrry): "_Retval" nodes are stateful, and so will be added
887 // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we
888 // specifically exclude them as seeds, to avoid unconditionally executing
889 // unused argument nodes (e.g. in a function like `lambda x, y: y`).
890 // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is
891 // still needed. It would be preferable to prune entire loops and/or
892 // conditionals if they are not used in the graph.
893 if (n->IsControlFlow() ||
894 (n->op_def().is_stateful() && n->type_string() != kArgOp) ||
895 (control_ret_nodes.find(n->name()) != control_ret_nodes.end())) {
896 nodes.insert(n);
897 }
898 }
899 bool changed = PruneForReverseReachability(g, std::move(nodes));
900 if (changed) {
901 FixupSourceAndSinkEdges(g);
902 }
903 }
904 } // namespace
905
CreateItem(Item ** item)906 Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
907 const FunctionBody* fbody;
908 FunctionLibraryRuntime* flr;
909 string executor_type;
910 {
911 tf_shared_lock l(mu_);
912 fbody = (*item)->func_graph;
913 flr = (*item)->overlay_flr
914 ? static_cast<FunctionLibraryRuntime*>((*item)->overlay_flr)
915 : static_cast<FunctionLibraryRuntime*>(this);
916 executor_type = (*item)->executor_type;
917 }
918 const FunctionLibraryDefinition* lib_def =
919 flr->GetFunctionLibraryDefinition();
920 std::unique_ptr<Graph> g(new Graph(lib_def));
921 CopyGraph(*fbody->graph, g.get());
922
923 PruneFunctionBody(fbody->fdef, g.get());
924 optimizer_.Optimize(this, env(), device(), &g, /*shape_map=*/nullptr);
925 TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()),
926 device()->name(), g.get()));
927
928 // Creates an executor based on the g. This must be done without
929 // holding mu_ because create_kernel_ calls back into the library.
930 LocalExecutorParams params;
931 params.device = device_;
932 params.function_library = flr;
933 if (flr == this) {
934 params.create_kernel = create_kernel_;
935 } else {
936 params.create_kernel =
937 [this, flr](const std::shared_ptr<const NodeProperties>& props,
938 OpKernel** kernel) {
939 return CreateKernel(props, flr, kernel);
940 };
941 }
942 params.delete_kernel = [](OpKernel* kernel) {
943 DeleteNonCachedKernel(kernel);
944 };
945 params.session_metadata = session_metadata_;
946 std::unique_ptr<Executor> exec;
947 TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, *g, &exec));
948 {
949 // Guard item since it is already inserted in items_.
950 mutex_lock l(mu_);
951 if ((*item)->exec == nullptr) {
952 (*item)->graph = std::move(g);
953 (*item)->exec = exec.release();
954 }
955 }
956 return Status::OK();
957 }
958
GetOrCreateItem(LocalHandle local_handle,Item ** item)959 Status FunctionLibraryRuntimeImpl::GetOrCreateItem(LocalHandle local_handle,
960 Item** item) {
961 {
962 tf_shared_lock l(mu_);
963 auto iter = items_->find(local_handle);
964 if (iter == items_->end()) {
965 return errors::Internal("Local function handle ", local_handle,
966 " is not valid. Likely an internal error.");
967 }
968 *item = iter->second.get();
969 if ((*item)->exec != nullptr) {
970 return Status::OK();
971 }
972 }
973 // NOTE: We need to call CreateItem out of mu_ because creating an
974 // executor needs to call CreateKernel.
975 return CreateItem(item);
976 }
977
ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options & run_opts,CallFrameInterface * frame,Executor::Args * exec_args)978 void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions(
979 const FunctionLibraryRuntime::Options& run_opts, CallFrameInterface* frame,
980 Executor::Args* exec_args) {
981 // Inherit the step_id from the caller.
982 exec_args->step_id = run_opts.step_id;
983 exec_args->rendezvous = run_opts.rendezvous;
984 exec_args->stats_collector = run_opts.stats_collector;
985 exec_args->cancellation_manager = run_opts.cancellation_manager;
986 exec_args->step_container = run_opts.step_container;
987 if (run_opts.runner) {
988 exec_args->runner = *run_opts.runner;
989 } else {
990 exec_args->runner = default_runner_;
991 }
992 exec_args->collective_executor = run_opts.collective_executor;
993 exec_args->call_frame = frame;
994 exec_args->run_all_kernels_inline = run_opts.run_all_kernels_inline;
995 }
996
RunRemote(const Options & opts,Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,Item * item,DoneCallback done)997 void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
998 gtl::ArraySlice<Tensor> args,
999 std::vector<Tensor>* rets,
1000 Item* item, DoneCallback done) {
1001 string target_device = parent_->GetDeviceName(handle);
1002 string source_device = opts.source_device;
1003 RendezvousInterface* rendezvous = opts.rendezvous;
1004 DeviceContext* device_context;
1005 Status s = parent_->GetDeviceContext(target_device, &device_context);
1006 if (!s.ok()) {
1007 done(s);
1008 return;
1009 }
1010 int64 src_incarnation, target_incarnation;
1011 s = parent_->GetDeviceIncarnation(source_device, &src_incarnation);
1012 s.Update(parent_->GetDeviceIncarnation(target_device, &target_incarnation));
1013 if (!s.ok()) {
1014 done(s);
1015 return;
1016 }
1017
1018 const FunctionBody* fbody = GetFunctionBody(handle);
1019 FunctionCallFrame* frame =
1020 new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
1021 Executor::Args* exec_args = new Executor::Args;
1022 ExecutorArgsFromOptions(opts, frame, exec_args);
1023
1024 std::vector<AllocatorAttributes> args_alloc_attrs, rets_alloc_attrs;
1025 args_alloc_attrs.reserve(fbody->arg_types.size());
1026 rets_alloc_attrs.reserve(fbody->ret_types.size());
1027 // Note: Functions assume that int32's are always on host memory.
1028 for (const auto& arg_type : fbody->arg_types) {
1029 AllocatorAttributes arg_alloc_attrs;
1030 if (MTypeFromDType(arg_type) == HOST_MEMORY) {
1031 arg_alloc_attrs.set_on_host(true);
1032 }
1033 args_alloc_attrs.push_back(arg_alloc_attrs);
1034 }
1035 for (const auto& ret_type : fbody->ret_types) {
1036 AllocatorAttributes ret_alloc_attrs;
1037 if (MTypeFromDType(ret_type) == HOST_MEMORY) {
1038 ret_alloc_attrs.set_on_host(true);
1039 }
1040 rets_alloc_attrs.push_back(ret_alloc_attrs);
1041 }
1042
1043 bool allow_dead_tensors = opts.allow_dead_tensors;
1044
1045 // The ProcFLR sends the arguments to the function from the source_device to
1046 // the target_device. So here we receive those arguments. Similarly, when the
1047 // computation is done and stored in *rets, we send the return values back
1048 // to the source_device (caller) so that the ProcFLR can receive them later.
1049 std::vector<Tensor>* remote_args = new std::vector<Tensor>;
1050 ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
1051 source_device, target_device, "arg_", src_incarnation, args.size(),
1052 device_context, args_alloc_attrs, rendezvous, remote_args,
1053 [frame, remote_args, item, source_device, target_device,
1054 target_incarnation, rendezvous, device_context, rets, done, exec_args,
1055 rets_alloc_attrs, allow_dead_tensors](const Status& status) {
1056 Status s = status;
1057 if (s.ok()) {
1058 s = frame->SetArgs(*remote_args);
1059 }
1060 if (!s.ok()) {
1061 delete frame;
1062 delete remote_args;
1063 delete exec_args;
1064 done(s);
1065 return;
1066 }
1067 item->exec->RunAsync(
1068 *exec_args,
1069 [frame, rets, done, source_device, target_device,
1070 target_incarnation, rendezvous, device_context, remote_args,
1071 rets_alloc_attrs, allow_dead_tensors](const Status& status) {
1072 Status s = status;
1073 if (s.ok()) {
1074 s = frame->ConsumeRetvals(rets, allow_dead_tensors);
1075 }
1076 delete frame;
1077 if (!s.ok()) {
1078 delete remote_args;
1079 done(s);
1080 return;
1081 }
1082 s = ProcessFunctionLibraryRuntime::SendTensors(
1083 target_device, source_device, "ret_", target_incarnation,
1084 *rets, device_context, rets_alloc_attrs, rendezvous);
1085 delete remote_args;
1086 done(s);
1087 });
1088 delete exec_args;
1089 });
1090 }
1091
Run(const Options & opts,Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,DoneCallback done)1092 void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
1093 gtl::ArraySlice<Tensor> args,
1094 std::vector<Tensor>* rets,
1095 DoneCallback done) {
1096 if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
1097 done(errors::Cancelled("Function was cancelled before it was started"));
1098 return;
1099 }
1100 Options run_opts = opts;
1101 if (opts.create_rendezvous) {
1102 auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_);
1103 run_opts.rendezvous = rendezvous;
1104 run_opts.create_rendezvous = false;
1105 done = [done = std::move(done), rendezvous](const Status& status) mutable {
1106 delete rendezvous;
1107 done(status);
1108 };
1109 }
1110
1111 LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
1112 if (local_handle == kInvalidLocalHandle) {
1113 parent_->Run(run_opts, handle, args, rets, done);
1114 return;
1115 }
1116
1117 if (run_opts.runner == nullptr) {
1118 run_opts.runner = &default_runner_;
1119 }
1120 DCHECK(run_opts.runner != nullptr);
1121
1122 Item* item = nullptr;
1123 Status s = GetOrCreateItem(local_handle, &item);
1124 if (!s.ok()) {
1125 done(s);
1126 return;
1127 }
1128
1129 if (run_opts.remote_execution) {
1130 // NOTE(mrry): `RunRemote()` will set `exec_args->call_frame` for us.
1131 RunRemote(run_opts, handle, args, rets, item, std::move(done));
1132 return;
1133 }
1134
1135 const FunctionBody* fbody = GetFunctionBody(handle);
1136 FunctionCallFrame* frame =
1137 new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
1138 s = frame->SetArgs(args);
1139 if (!s.ok()) {
1140 delete frame;
1141 done(s);
1142 return;
1143 }
1144
1145 profiler::TraceMeProducer activity(
1146 // To TraceMeConsumers in ExecutorState::Process/Finish.
1147 [&opts] {
1148 return profiler::TraceMeEncode("FunctionRun", {{"id", opts.step_id}});
1149 },
1150 profiler::ContextType::kTfExecutor, opts.step_id,
1151 profiler::TraceMeLevel::kInfo);
1152
1153 Executor::Args exec_args;
1154 ExecutorArgsFromOptions(run_opts, frame, &exec_args);
1155
1156 bool allow_dead_tensors = run_opts.allow_dead_tensors;
1157 item->exec->RunAsync(
1158 // Executor args
1159 exec_args,
1160 // Done callback.
1161 [frame, rets, done, allow_dead_tensors](const Status& status) {
1162 Status s = status;
1163 if (s.ok()) {
1164 s = frame->ConsumeRetvals(rets, allow_dead_tensors);
1165 }
1166 delete frame;
1167 done(s);
1168 });
1169 }
1170
Run(const Options & opts,Handle handle,CallFrameInterface * frame,DoneCallback done)1171 void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
1172 CallFrameInterface* frame,
1173 DoneCallback done) {
1174 if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
1175 done(errors::Cancelled(""));
1176 return;
1177 }
1178
1179 Options run_opts = opts;
1180 if (opts.create_rendezvous) {
1181 auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_);
1182 run_opts.rendezvous = rendezvous;
1183 run_opts.create_rendezvous = false;
1184 done = [done = std::move(done), rendezvous](const Status& status) mutable {
1185 delete rendezvous;
1186 done(status);
1187 };
1188 }
1189
1190 LocalHandle local_handle = parent_->GetHandleOnDevice(
1191 device_name_, handle, /*include_multi_device=*/true);
1192 if (local_handle == kInvalidLocalHandle) {
1193 parent_->Run(run_opts, handle, frame, done);
1194 return;
1195 }
1196
1197 if (opts.remote_execution) {
1198 // NOTE(mrry): This bit is only set for a local function when `parent_`
1199 // calls back into this class, and the current implementation of
1200 // `ProcessFunctionLibraryRuntime` currently always uses the vector-based
1201 // `args`/`rets` interface.
1202 done(errors::Unimplemented("Remote calling with CallFrameInterface"));
1203 return;
1204 }
1205
1206 Item* item = nullptr;
1207 Status s = GetOrCreateItem(local_handle, &item);
1208 if (!s.ok()) {
1209 done(s);
1210 return;
1211 }
1212 if (run_opts.runner == nullptr) {
1213 run_opts.runner = &default_runner_;
1214 }
1215 DCHECK(run_opts.runner != nullptr);
1216
1217 profiler::TraceMeProducer activity(
1218 // To TraceMeConsumers in ExecutorState::Process/Finish.
1219 [&opts] {
1220 return profiler::TraceMeEncode("FunctionRun", {{"id", opts.step_id}});
1221 },
1222 profiler::ContextType::kTfExecutor, opts.step_id,
1223 profiler::TraceMeLevel::kInfo);
1224
1225 Executor::Args exec_args;
1226 ExecutorArgsFromOptions(run_opts, frame, &exec_args);
1227 item->exec->RunAsync(exec_args, std::move(done));
1228 }
1229
PrepareRunSync(Handle handle,Options * run_opts,Item ** out_item,std::unique_ptr<PrivateIntraProcessRendezvous> * out_rendezvous)1230 Status FunctionLibraryRuntimeImpl::PrepareRunSync(
1231 Handle handle, Options* run_opts, Item** out_item,
1232 std::unique_ptr<PrivateIntraProcessRendezvous>* out_rendezvous) {
1233 if (run_opts->cancellation_manager &&
1234 run_opts->cancellation_manager->IsCancelled()) {
1235 return errors::Cancelled("");
1236 }
1237
1238 if (run_opts->remote_execution) {
1239 // NOTE(mrry): This bit is only set for a local function when `parent_`
1240 // calls back into this class, and the current implementation of
1241 // `ProcessFunctionLibraryRuntime` currently always uses the asynchronous
1242 // Run() method.
1243 return errors::Unimplemented("Remote calling with RunSync()");
1244 }
1245
1246 if (run_opts->create_rendezvous) {
1247 *out_rendezvous =
1248 absl::make_unique<PrivateIntraProcessRendezvous>(device_mgr_);
1249 run_opts->rendezvous = out_rendezvous->get();
1250 run_opts->create_rendezvous = false;
1251 }
1252
1253 LocalHandle local_handle = parent_->GetHandleOnDevice(
1254 device_name_, handle, /*include_multi_device=*/true);
1255 if (local_handle == kInvalidLocalHandle) {
1256 *out_item = nullptr;
1257 return Status::OK();
1258 }
1259
1260 TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, out_item));
1261
1262 if (run_opts->runner == nullptr) {
1263 run_opts->runner = &default_runner_;
1264 }
1265 DCHECK(run_opts->runner != nullptr);
1266
1267 return Status::OK();
1268 }
1269
RunSync(Options opts,Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets)1270 Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
1271 gtl::ArraySlice<Tensor> args,
1272 std::vector<Tensor>* rets) {
1273 Item* item = nullptr;
1274 std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
1275 TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
1276 if (item == nullptr) {
1277 return parent_->RunSync(opts, handle, args, rets);
1278 }
1279
1280 Executor::Args exec_args;
1281 const FunctionBody* fbody = GetFunctionBody(handle);
1282 FunctionCallFrame frame(fbody->arg_types, fbody->ret_types);
1283 TF_RETURN_IF_ERROR(frame.SetArgs(args));
1284 ExecutorArgsFromOptions(opts, &frame, &exec_args);
1285
1286 TF_RETURN_IF_ERROR(item->exec->Run(exec_args));
1287 return frame.ConsumeRetvals(rets, opts.allow_dead_tensors);
1288 }
1289
RunSync(Options opts,Handle handle,CallFrameInterface * call_frame)1290 Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
1291 CallFrameInterface* call_frame) {
1292 Item* item = nullptr;
1293 std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
1294 TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
1295 if (item == nullptr) {
1296 return parent_->RunSync(opts, handle, call_frame);
1297 }
1298
1299 Executor::Args exec_args;
1300 ExecutorArgsFromOptions(opts, call_frame, &exec_args);
1301 return item->exec->Run(exec_args);
1302 }
1303
IsStateful(const string & func) const1304 bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) const {
1305 const OpDef* op_def;
1306 const Status s = base_lib_def_->LookUpOpDef(func, &op_def);
1307 return s.ok() && op_def->is_stateful();
1308 }
1309
DebugString(Handle handle)1310 string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
1311 Item* item = nullptr;
1312 LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
1313 Status s = GetOrCreateItem(local_handle, &item);
1314 if (s.ok()) {
1315 if (item->graph) {
1316 return tensorflow::DebugString(item->graph.get());
1317 } else {
1318 return tensorflow::DebugString(item->func_graph->graph);
1319 }
1320 } else {
1321 return s.ToString();
1322 }
1323 }
1324
Clone(std::unique_ptr<FunctionLibraryDefinition> * out_lib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * out_pflr,FunctionLibraryRuntime ** out_flr,bool skip_flib_def)1325 Status FunctionLibraryRuntimeImpl::Clone(
1326 std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
1327 std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
1328 FunctionLibraryRuntime** out_flr, bool skip_flib_def) {
1329 TF_RETURN_IF_ERROR(parent_->Clone(env_, graph_def_version_,
1330 optimizer_.options(), out_lib_def, out_pflr,
1331 skip_flib_def));
1332 *out_flr = (*out_pflr)->GetFLR(device_->name());
1333 if (*out_flr != nullptr) {
1334 return Status::OK();
1335 } else {
1336 return errors::Internal("Cloning FunctionLibraryRuntime failed.");
1337 }
1338 }
1339
1340 namespace {
1341
1342 struct CustomCreatorSingleton {
1343 mutex mu;
1344 CustomKernelCreator* custom_creator = nullptr;
1345
Settensorflow::__anon9a8797221111::CustomCreatorSingleton1346 void Set(CustomKernelCreator* cb) {
1347 mutex_lock l(mu);
1348 custom_creator = cb;
1349 }
1350
Gettensorflow::__anon9a8797221111::CustomCreatorSingleton1351 CustomKernelCreator* Get() {
1352 mutex_lock l(mu);
1353 return custom_creator;
1354 }
1355 };
1356
GetCustomCreatorSingleton()1357 CustomCreatorSingleton* GetCustomCreatorSingleton() {
1358 static CustomCreatorSingleton* ccs = new CustomCreatorSingleton;
1359 return ccs;
1360 }
1361
1362 } // namespace
1363
GetDefaultCustomKernelCreator()1364 const CustomKernelCreator* GetDefaultCustomKernelCreator() {
1365 return GetCustomCreatorSingleton()->Get();
1366 }
1367
RegisterDefaultCustomKernelCreator(CustomKernelCreator * c)1368 void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c) {
1369 GetCustomCreatorSingleton()->Set(c);
1370 }
1371
NewFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,const ConfigProto * config,Device * device,int graph_def_version,const FunctionLibraryDefinition * lib_def,thread::ThreadPool * thread_pool,const OptimizerOptions & optimizer_options,const SessionMetadata * session_metadata,ProcessFunctionLibraryRuntime * parent)1372 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
1373 const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
1374 Device* device, int graph_def_version,
1375 const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool,
1376 const OptimizerOptions& optimizer_options,
1377 const SessionMetadata* session_metadata,
1378 ProcessFunctionLibraryRuntime* parent) {
1379 return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
1380 device_mgr, env, config, device, graph_def_version, lib_def, thread_pool,
1381 optimizer_options, session_metadata, parent));
1382 }
1383
1384 class SymbolicGradientHelper {
1385 public:
SymbolicGradientHelper(const FunctionBody & f)1386 explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {}
1387 ~SymbolicGradientHelper() = default;
1388
1389 std::unique_ptr<FunctionBody> Compute();
1390
1391 private:
1392 const FunctionBody* fbody_;
1393
1394 // Makes a copy of fbody_ in gbody.
1395 void Copy(FunctionBody* gbody);
1396
1397 TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientHelper);
1398 };
1399
Copy(FunctionBody * gbody)1400 void SymbolicGradientHelper::Copy(FunctionBody* gbody) {
1401 const Graph& src = *(fbody_->graph);
1402 gbody->graph = new Graph(src.op_registry());
1403 Graph* dst = gbody->graph;
1404
1405 std::vector<Node*> node_map(src.num_node_ids());
1406
1407 // Copy just the fdef attributes (copy '_noinline' and other similar flags to
1408 // the gradient function body).
1409 *(gbody->fdef.mutable_attr()) = fbody_->fdef.attr();
1410
1411 // Copy the nodes.
1412 node_map[src.source_node()->id()] = dst->source_node();
1413 node_map[src.sink_node()->id()] = dst->sink_node();
1414 for (Node* n : src.op_nodes()) {
1415 node_map[n->id()] = dst->CopyNode(n);
1416 }
1417
1418 // Copy the edges.
1419 for (const Edge* e : src.edges()) {
1420 Node* src_copy = node_map[e->src()->id()];
1421 Node* dst_copy = node_map[e->dst()->id()];
1422 dst->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
1423 }
1424
1425 // Save inputs in copied graph.
1426 CHECK_EQ(fbody_->arg_types.size(), fbody_->arg_nodes.size());
1427 gbody->arg_types = fbody_->arg_types;
1428 for (std::size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
1429 gbody->arg_nodes.push_back(node_map[fbody_->arg_nodes[i]->id()]);
1430 }
1431
1432 // Save outputs in copied graph.
1433 CHECK_EQ(fbody_->ret_types.size(), fbody_->ret_nodes.size());
1434 gbody->ret_types = fbody_->ret_types;
1435 for (std::size_t i = 0; i < fbody_->ret_nodes.size(); ++i) {
1436 gbody->ret_nodes.push_back(node_map[fbody_->ret_nodes[i]->id()]);
1437 }
1438 }
1439
Compute()1440 std::unique_ptr<FunctionBody> SymbolicGradientHelper::Compute() {
1441 FunctionBody* gbody = new FunctionBody;
1442 Copy(gbody); // copy fbody_ into gbody.
1443
1444 Graph* g = gbody->graph;
1445
1446 const int num_y = static_cast<int>(gbody->ret_nodes.size());
1447
1448 // Populate 'y_node_outputs_' with node function body outputs.
1449 // Populate 'y_grad_nodes' with initial gradient nodes for each return node
1450 // of the original function body (these will be 'arg' nodes in the function
1451 // gradient body).
1452 std::vector<NodeOut> y_node_outputs;
1453 y_node_outputs.reserve(num_y);
1454 std::vector<NodeOut> y_grad_node_outputs;
1455 y_grad_node_outputs.reserve(num_y);
1456 for (int i = 0; i < num_y; ++i) {
1457 Node* y = gbody->ret_nodes[i];
1458 y_node_outputs.push_back({y, 0});
1459 DCHECK_EQ(y->type_string(), kRetOp);
1460 const DataType dtype = y->input_type(0);
1461 const int index = static_cast<int>(gbody->arg_nodes.size());
1462 Node* dy = AddArg(g, dtype, index);
1463 gbody->arg_types.push_back(dtype);
1464 gbody->arg_nodes.push_back(dy);
1465 y_grad_node_outputs.push_back({dy, 0});
1466 }
1467
1468 // Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
1469 const size_t num_x = fbody_->arg_nodes.size();
1470 std::vector<NodeOut> x_node_outputs;
1471 x_node_outputs.reserve(num_x);
1472 for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
1473 x_node_outputs.push_back({gbody->arg_nodes[i], 0});
1474 }
1475
1476 // Call AddSymbolicGradients which will add nodes to graph 'g' that
1477 // compute the function gradient (adding an entry in 'x_grad_node_outputs'
1478 // for each node in 'x_node_outputs').
1479 std::vector<NodeOut> x_grad_node_outputs;
1480 TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs,
1481 y_grad_node_outputs, &x_grad_node_outputs,
1482 g));
1483
1484 // Remove the old return nodes from the function body.
1485 for (Node* n : gbody->ret_nodes) {
1486 g->RemoveNode(n);
1487 }
1488 gbody->ret_types = fbody_->arg_types;
1489 // TODO(apassos): use the right dtype for gradients of resource variables
1490 for (int i = 0; i < gbody->ret_types.size(); ++i) {
1491 if (gbody->ret_types[i] == DT_RESOURCE) {
1492 gbody->ret_types[i] = DT_FLOAT;
1493 }
1494 }
1495 gbody->ret_nodes.clear();
1496 // Add new return nodes to the function gradient body for each node
1497 // in 'x_grad_nodes'.
1498 const int arg_types_size = static_cast<int>(fbody_->arg_types.size());
1499 for (int i = 0; i < arg_types_size; ++i) {
1500 Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
1501 Node* ret = AddRet(g, grad, i);
1502 gbody->ret_nodes.push_back(ret);
1503 }
1504
1505 return std::unique_ptr<FunctionBody>(gbody);
1506 }
1507
SymbolicGradient(const FunctionBody & f)1508 std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f) {
1509 return SymbolicGradientHelper(f).Compute();
1510 }
1511
1512 } // end namespace tensorflow
1513