/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"

#include <memory>

#include "absl/strings/match.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/profiler/lib/annotated_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#endif  // !IS_MOBILE_PLATFORM

namespace tensorflow {

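// Returns the tensor backing argument `index`. EagerKernelArgs stores only
// flat (non-packed) arguments, so any request with a sub_index is rejected.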
Status EagerKernelArgs::GetLocalArg(const FunctionArgIndex& index,
                                    Tensor* val) const {
  if (index.sub_index >= 0) {
    return errors::InvalidArgument("Got unexpected sub_index ", index.sub_index,
                                   " for argument ", index.index);
  }
  Tensor* arg = tensor_args_.at(index.index).tensor;
  if (arg) {
    *val = *arg;
    return Status::OK();
  } else {
    return errors::NotFound("Argument ", index.index, " has no local tensor.");
  }
}

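// Returns a copy of every argument tensor. Tensor copies are shallow
// (buffer-sharing), so no tensor data is duplicated.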
std::vector<Tensor> EagerKernelArgs::GetLocalTensors() const {
  std::vector<Tensor> local_inputs;
  local_inputs.reserve(tensor_args_.size());
  for (const TensorValue& tensor_value : tensor_args_) {
    local_inputs.push_back(*tensor_value.tensor);
  }
  return local_inputs;
}

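// Returns the runner used to schedule closures: the one injected at
// construction time if present, otherwise a process-wide default runner that
// simply executes the closure inline on the calling thread.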
std::function<void(std::function<void()>)>* KernelAndDevice::get_runner()
    const {
  if (runner_) {
    return runner_;
  } else {
    static auto* default_runner =
        new std::function<void(std::function<void()>)>(
            [](const std::function<void()>& f) { f(); });
    return default_runner;
  }
}

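// Releases the multi-device function handle obtained in InstantiateFunc();
// a failure to release is logged and otherwise ignored.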
KernelAndDeviceFunc::~KernelAndDeviceFunc() {
  if (handle_ != kInvalidHandle) {
    Status status = pflr_->ReleaseHandle(handle_);
    if (!status.ok()) {
      LOG(INFO) << "Ignoring error status when releasing multi-device function "
                   "handle "
                << status.ToString();
    }
  }
}

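// Builds the OpKernel for `ndef` via the FunctionLibraryRuntime and caches
// per-input/per-output allocator attributes. Inputs that the kernel expects
// in HOST_MEMORY are remapped to the host CPU device.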
Status KernelAndDeviceOp::Init(const bool log_device_placement,
                               const NodeDef& ndef,
                               GraphCollector* graph_collector) {
  OpKernel* k = nullptr;
  if (flr_ == nullptr) {
    return errors::Internal(
        "A valid FunctionLibraryRuntime must be provided when running ops "
        "based on OpKernel.");
  }
  std::shared_ptr<const NodeProperties> props;
  TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
      ndef, flr_->GetFunctionLibraryDefinition(), &props));
  TF_RETURN_IF_ERROR(flr_->CreateKernel(props, &k));
  kernel_.reset(k);

  input_alloc_attrs_.resize(kernel_->num_inputs());
  input_devices_.resize(kernel_->num_inputs(), device_);
  for (size_t i = 0; i < input_alloc_attrs_.size(); ++i) {
    bool host = kernel_->input_memory_types()[i] == tensorflow::HOST_MEMORY;
    input_alloc_attrs_[i].set_on_host(host);
    if (host) {
      input_devices_[i] = host_cpu_device_;
    }
  }
  output_alloc_attrs_.resize(kernel_->num_outputs());
  for (size_t i = 0; i < output_alloc_attrs_.size(); ++i) {
    output_alloc_attrs_[i].set_on_host(kernel_->output_memory_types()[i] ==
                                       tensorflow::HOST_MEMORY);
  }

  return Status::OK();
}

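// Instantiates `ndef` as a multi-device function through the
// ProcessFunctionLibraryRuntime: resolves the OpDef/FunctionDef, sets up
// input/output device placement, optionally plugs in grappler as the graph
// optimization function, and records whether the resulting handle is
// cross-process.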
Status KernelAndDeviceFunc::InstantiateFunc(const bool log_device_placement,
                                            const NodeDef& ndef,
                                            GraphCollector* graph_collector) {
  const OpDef* op_def = nullptr;
  const FunctionDef* function_def;
  if (flr_ == nullptr) {
    // If the function is being executed without an explicit device request,
    // look up the FunctionDef in the CPU's FLR. All FLRs share the same
    // library.
    function_def = pflr_->GetFLR(host_cpu_device_->name())
                       ->GetFunctionLibraryDefinition()
                       ->Find(ndef.op());
  } else {
    function_def = flr_->GetFunctionLibraryDefinition()->Find(ndef.op());
  }

  if (function_def != nullptr) {
    op_def = &(function_def->signature());
  } else {
    TF_RETURN_IF_ERROR(OpDefForOp(ndef.op(), &op_def));
  }
  TF_RETURN_IF_ERROR(
      InOutTypesForNode(ndef, *op_def, &input_dtypes_, &output_dtypes_));

  FunctionLibraryRuntime::InstantiateOptions options;
  options.target = device_ == nullptr ? "" : device_->name();
  options.is_multi_device_function = true;
  for (const Device* device : input_devices_) {
    options.input_devices.push_back(device->name());
  }
  options.composite_devices = composite_devices_;
  options.input_resource_dtypes_and_shapes = input_resource_dtypes_and_shapes_;
  if (outputs_on_op_device_) {
    const FunctionLibraryDefinition* lib_def =
        pflr_->GetFunctionLibraryDefinition();
    const FunctionDef* fdef = lib_def->Find(ndef.op());
    if (fdef == nullptr) {
      return errors::InvalidArgument("Failed to find function ", ndef.op());
    }
    for (int i = 0; i < fdef->signature().output_arg_size(); ++i) {
      options.output_devices.push_back(options.target);
    }
  }

  const auto& it = ndef.attr().find("executor_type");
  if (it != ndef.attr().end()) {
    options.executor_type = it->second.s();
  }
  const auto& is_component_fn_it = ndef.attr().find("is_component_function");
  if (is_component_fn_it != ndef.attr().end()) {
    options.is_component_function = is_component_fn_it->second.b();
  }
#if !defined(IS_MOBILE_PLATFORM)
  // Android tf library does not include grappler.
  const auto& config_it = ndef.attr().find("config_proto");
  if (config_it != ndef.attr().end()) {
    if (!options.config_proto.ParseFromString(config_it->second.s())) {
      return errors::InvalidArgument(
          "Failed to parse config_proto attribute as tensorflow::ConfigProto "
          "proto.");
    }
    grappler::GrapplerItem::OptimizationOptions optimization_options =
        grappler::CreateOptOptionsForEager();

    options.optimize_graph_fn = std::bind(
        grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2,
        std::placeholders::_3, std::placeholders::_4, std::placeholders::_5,
        options.config_proto, function_def->signature().name(),
        optimization_options, std::placeholders::_6);
  }
#endif  // !IS_MOBILE_PLATFORM
  options.graph_collector = graph_collector;

  // In eager mode we always inline all functions into the top-level function
  // body graph, to get a single executable graph that can be optimized across
  // function boundaries (e.g. by pruning unused inputs and outputs in a
  // function call chain). This is required to mimic graph mode execution, with
  // its aggressive pruning of nodes not in the transitive fanin of the fetches.
  options.config_proto.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(true);

  options.config_proto.set_log_device_placement(log_device_placement);

  TF_RETURN_IF_ERROR(
      pflr_->Instantiate(ndef.op(), AttrSlice(ndef), options, &handle_));
  return pflr_->IsCrossProcess(handle_, &is_cross_process_);
}

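// Instantiates the function and then queries the runtime for the devices on
// which each of its outputs will be produced.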
Status KernelAndDeviceFunc::Init(const bool log_device_placement,
                                 const NodeDef& ndef,
                                 GraphCollector* graph_collector) {
  TF_RETURN_IF_ERROR(
      InstantiateFunc(log_device_placement, ndef, graph_collector));
  return pflr_->GetOutputDevices(handle_, &output_devices_);
}

namespace {
// In certain contexts (e.g. TPU async executions), the CancellationManager is
// used to shut down the device in error scenarios (as opposed to using the
// AsyncCompute's DoneCallback). This is handled through the
// {inc,dec}_num_deferred_ops_function.
struct OpExecutionState : public core::RefCounted {
  // TODO(nareshmodi): consider refcounting the cancellation_manager.
  CancellationManager cancellation_manager;
};
}  // anonymous namespace

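// Executes the kernel synchronously on `device_`. A minimal usage sketch
// (illustrative only; constructor arguments and error handling elided, and
// the variable names below are not from this file) given an EagerKernelArgs
// `args` built from the op's inputs:
//
//   KernelAndDeviceOp kernel(/*...*/);
//   TF_RETURN_IF_ERROR(kernel.Init(/*log_device_placement=*/false, ndef,
//                                  /*graph_collector=*/nullptr));
//   std::vector<EagerKernelRet> rets;
//   TF_RETURN_IF_ERROR(kernel.Run(/*step_container=*/nullptr, args, &rets,
//                                 /*cancellation_manager=*/nullptr,
//                                 /*remote_func_params=*/absl::nullopt,
//                                 /*stack_trace=*/absl::nullopt));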
Status KernelAndDeviceOp::Run(
    ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
    std::vector<EagerKernelRet>* outputs,
    CancellationManager* cancellation_manager,
    const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
    const absl::optional<ManagedStackTrace>& stack_trace) {
  OpKernelContext::Params params;
  params.device = device_;
  params.frame_iter = FrameAndIter(0, 0);
  params.inputs = inputs.GetTensorValues();
  params.op_kernel = kernel_.get();
  params.resource_manager = device_->resource_manager();
  params.input_alloc_attrs = &input_alloc_attrs_;
  params.output_attr_array = output_alloc_attrs_.data();
  params.function_library = flr_;
  params.slice_reader_cache = &slice_reader_cache_;
  params.rendezvous = rendezvous_;
  params.stack_trace = stack_trace;
  OpExecutionState* op_execution_state = nullptr;

  CancellationManager default_cancellation_manager;
  if (cancellation_manager) {
    params.cancellation_manager = cancellation_manager;
  } else if (kernel_->is_deferred()) {
    op_execution_state = new OpExecutionState;
    params.cancellation_manager = &op_execution_state->cancellation_manager;
    params.inc_num_deferred_ops_function = [op_execution_state]() {
      op_execution_state->Ref();
    };
    params.dec_num_deferred_ops_function = [op_execution_state]() {
      op_execution_state->Unref();
    };
  } else {
    params.cancellation_manager = &default_cancellation_manager;
  }

  params.log_memory = log_memory_;

  params.runner = get_runner();

  params.step_container = step_container;

  params.collective_executor =
      collective_executor_ ? collective_executor_->get() : nullptr;

  OpKernelContext context(&params);

  {
    port::ScopedFlushDenormal flush;
    port::ScopedSetRound round(FE_TONEAREST);
    // 'AnnotatedTraceMe' will trace both scheduling time on host and execution
    // time on device of the OpKernel.
    profiler::AnnotatedTraceMe activity(
        [&] { return kernel_->TraceString(context, /*verbose=*/false); },
        profiler::TraceMeLevel::kInfo);
    device_->Compute(kernel_.get(), &context);
  }

  // Release this function's reference on op_execution_state; the state (and
  // its CancellationManager) is destroyed once any deferred ops holding
  // references have also released theirs.
  if (op_execution_state != nullptr) {
    op_execution_state->Unref();
  }

  if (!context.status().ok()) return context.status();

  if (outputs != nullptr) {
    outputs->clear();
    for (int i = 0; i < context.num_outputs(); ++i) {
      const auto* output_tensor = context.mutable_output(i);
      if (output_tensor != nullptr) {
        outputs->push_back(Tensor(*output_tensor));
      } else {
        outputs->push_back(Tensor());
      }
    }
  }
  return Status::OK();
}

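// Synchronous wrapper around RunAsync(): blocks on a Notification until the
// asynchronous invocation reports its final status.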
Status KernelAndDeviceFunc::Run(
    ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
    std::vector<EagerKernelRet>* outputs,
    CancellationManager* cancellation_manager,
    const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
    const absl::optional<ManagedStackTrace>& stack_trace) {
  Notification n;
  Status status;
  RunAsync(step_container, inputs, outputs, cancellation_manager,
           remote_func_params, [&status, &n](const Status& s) {
             status = s;
             n.Notify();
           });
  n.WaitForNotification();
  return status;
}

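// Runs the instantiated multi-device function asynchronously. Builds the
// FunctionLibraryRuntime::Options (step id, op id, per-call rendezvous,
// cancellation manager, runner) and hands execution to the
// ProcessFunctionLibraryRuntime; `done` is invoked with the final status.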
void KernelAndDeviceFunc::RunAsync(
    ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
    std::vector<EagerKernelRet>* outputs,
    CancellationManager* cancellation_manager,
    const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
    std::function<void(const Status&)> done) {
  std::shared_ptr<FunctionLibraryRuntime::Options> opts = nullptr;
  if (remote_func_params.has_value()) {
    const EagerRemoteFunctionParams& params = remote_func_params.value();
    if (params.step_id.has_value()) {
      // If the function is a remote component of a cross-process function,
      // reuse its parent function's step id.
      opts = std::make_shared<FunctionLibraryRuntime::Options>(
          params.step_id.value());
    } else {
      opts = std::make_shared<FunctionLibraryRuntime::Options>();
    }
    // Reuse the op id if it exists.
    opts->op_id = params.op_id;
  } else {
    opts = std::make_shared<FunctionLibraryRuntime::Options>();
    if (get_op_id_ && is_cross_process_) {
      // If the function is a cross-process function and the remote execution
      // goes through eager service, create an eager op id for the function.
      opts->op_id = get_op_id_();
    }
  }

  // We don't pass rendezvous from eager context because we can get tensor
  // name collisions in send/recv ops when running multiple instances
  // of the same multi-device function concurrently.
  Rendezvous* rendezvous = rendezvous_creator_(opts->step_id);
  opts->rendezvous = rendezvous;
  opts->create_rendezvous = false;

  // Create a cancellation manager to be used by FLR options if the caller does
  // not pass one in. If the caller does provide one, pass it through to the
  // process FLR and skip creating a local manager.
  std::shared_ptr<CancellationManager> local_cm;
  if (cancellation_manager) {
    opts->cancellation_manager = cancellation_manager;
  } else {
    local_cm = std::make_shared<CancellationManager>();
    opts->cancellation_manager = local_cm.get();
  }
  opts->allow_dead_tensors = true;
  opts->step_container = step_container;
  opts->collective_executor =
      collective_executor_ ? collective_executor_->get() : nullptr;

  opts->stats_collector = nullptr;
  opts->runner = get_runner();

  outputs->clear();

  pflr_->Run(*opts, handle_, inputs, outputs,
             [opts, rendezvous, local_cm, step_container, this,
              done = std::move(done)](const Status& s) {
               rendezvous->Unref();
               done(s);
             });
}

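// Output device accessors. OutputDevice() returns the device on which an
// output tensor is produced, or nullptr for host-memory op outputs and
// DT_RESOURCE function outputs. OutputResourceDevice() applies only to
// DT_RESOURCE outputs and returns the device on which the resource lives.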
tensorflow::Device* KernelAndDeviceOp::OutputDevice(int idx) const {
  if (kernel_->output_memory_types()[idx] == HOST_MEMORY) {
    return nullptr;
  }
  return device_;
}

tensorflow::Device* KernelAndDeviceFunc::OutputDevice(int idx) const {
  if (output_dtypes_[idx] == DT_RESOURCE) {
    return nullptr;
  }
  return output_devices_[idx];
}

tensorflow::Device* KernelAndDeviceOp::OutputResourceDevice(int idx) const {
  if (kernel_->output_type(idx) == DT_RESOURCE) {
    return device_;
  }
  return nullptr;
}

tensorflow::Device* KernelAndDeviceFunc::OutputResourceDevice(int idx) const {
  if (output_dtypes_[idx] == DT_RESOURCE) {
    return output_devices_[idx];
  }
  return nullptr;
}

Device* KernelAndDeviceOp::InputDevice(int i) const {
  return input_devices_[i];
}

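// Resource-handle arguments are fed from host CPU unless their device is a
// registered composite device, in which case that composite device is
// reported.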
Device* KernelAndDeviceFunc::InputDevice(int i) const {
  if ((input_dtypes_[i] == DT_RESOURCE) &&
      (composite_devices_.find(input_devices_[i]->name()) ==
       composite_devices_.end())) {
    return host_cpu_device_;
  } else {
    return input_devices_[i];
  }
}

}  // namespace tensorflow