/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"

#include <memory>

#include "absl/strings/match.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/profiler/lib/annotated_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#endif  // !IS_MOBILE_PLATFORM

namespace tensorflow {

Status EagerKernelArgs::GetLocalArg(const FunctionArgIndex& index,
                                    Tensor* val) const {
  if (index.sub_index >= 0) {
    return errors::InvalidArgument("Got unexpected sub_index ", index.sub_index,
                                   " for argument ", index.index);
  }
  Tensor* arg = tensor_args_.at(index.index).tensor;
  if (arg) {
    *val = *arg;
    return Status::OK();
  } else {
    return errors::NotFound("Argument ", index.index, " has no local tensor.");
  }
}

std::vector<Tensor> EagerKernelArgs::GetLocalTensors() const {
  std::vector<Tensor> local_inputs;
  local_inputs.reserve(tensor_args_.size());
  for (const TensorValue& tensor_value : tensor_args_) {
    local_inputs.push_back(*tensor_value.tensor);
  }
  return local_inputs;
}

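// Returns the runner set at construction time; if none was provided, falls
// back to a process-wide default runner that invokes the closure inline on
// the calling thread.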
std::function<void(std::function<void()>)>* KernelAndDevice::get_runner()
    const {
  if (runner_) {
    return runner_;
  } else {
    static auto* default_runner =
        new std::function<void(std::function<void()>)>(
            [](const std::function<void()>& f) { f(); });
    return default_runner;
  }
}

KernelAndDeviceFunc::~KernelAndDeviceFunc() {
  if (handle_ != kInvalidHandle) {
    Status status = pflr_->ReleaseHandle(handle_);
    if (!status.ok()) {
      LOG(INFO) << "Ignoring error status when releasing multi-device function "
                   "handle "
                << status.ToString();
    }
  }
}

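// Builds the OpKernel for `ndef` through the FunctionLibraryRuntime and
// records, for every input and output, whether it lives in host or device
// memory so that allocator attributes and input devices are set up once.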
Status KernelAndDeviceOp::Init(const bool log_device_placement,
                               const NodeDef& ndef,
                               GraphCollector* graph_collector) {
  OpKernel* k = nullptr;
  if (flr_ == nullptr) {
    return errors::Internal(
        "A valid FunctionLibraryRuntime must be provided when running ops "
        "based on OpKernel.");
  }
  std::shared_ptr<const NodeProperties> props;
  TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
      ndef, flr_->GetFunctionLibraryDefinition(), &props));
  TF_RETURN_IF_ERROR(flr_->CreateKernel(props, &k));
  kernel_.reset(k);

  input_alloc_attrs_.resize(kernel_->num_inputs());
  input_devices_.resize(kernel_->num_inputs(), device_);
  for (size_t i = 0; i < input_alloc_attrs_.size(); ++i) {
    bool host = kernel_->input_memory_types()[i] == tensorflow::HOST_MEMORY;
    input_alloc_attrs_[i].set_on_host(host);
    if (host) {
      input_devices_[i] = host_cpu_device_;
    }
  }
  output_alloc_attrs_.resize(kernel_->num_outputs());
  for (size_t i = 0; i < output_alloc_attrs_.size(); ++i) {
    output_alloc_attrs_[i].set_on_host(kernel_->output_memory_types()[i] ==
                                       tensorflow::HOST_MEMORY);
  }

  return Status::OK();
}

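// Instantiates `ndef` as a multi-device function via the
// ProcessFunctionLibraryRuntime: resolves input/output dtypes and devices,
// forwards executor and config attributes, and (outside mobile builds) plugs
// Grappler-based graph optimization into the instantiation options.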
Status KernelAndDeviceFunc::InstantiateFunc(const bool log_device_placement,
                                            const NodeDef& ndef,
                                            GraphCollector* graph_collector) {
  const OpDef* op_def = nullptr;
  const FunctionDef* function_def;
  if (flr_ == nullptr) {
    // If the function is being executed without an explicit device request,
    // look up the FunctionDef in the CPU's FLR. All FLRs share the same
    // library.
    function_def = pflr_->GetFLR(host_cpu_device_->name())
                       ->GetFunctionLibraryDefinition()
                       ->Find(ndef.op());
  } else {
    function_def = flr_->GetFunctionLibraryDefinition()->Find(ndef.op());
  }

  if (function_def != nullptr) {
    op_def = &(function_def->signature());
  } else {
    TF_RETURN_IF_ERROR(OpDefForOp(ndef.op(), &op_def));
  }
  TF_RETURN_IF_ERROR(
      InOutTypesForNode(ndef, *op_def, &input_dtypes_, &output_dtypes_));

  FunctionLibraryRuntime::InstantiateOptions options;
  options.target = device_ == nullptr ? "" : device_->name();
  options.is_multi_device_function = true;
  for (const Device* device : input_devices_) {
    options.input_devices.push_back(device->name());
  }
  options.composite_devices = composite_devices_;
  options.input_resource_dtypes_and_shapes = input_resource_dtypes_and_shapes_;
  if (outputs_on_op_device_) {
    const FunctionLibraryDefinition* lib_def =
        pflr_->GetFunctionLibraryDefinition();
    const FunctionDef* fdef = lib_def->Find(ndef.op());
    if (fdef == nullptr) {
      return errors::InvalidArgument("Failed to find function ", ndef.op());
    }
    for (int i = 0; i < fdef->signature().output_arg_size(); ++i) {
      options.output_devices.push_back(options.target);
    }
  }

  const auto& it = ndef.attr().find("executor_type");
  if (it != ndef.attr().end()) {
    options.executor_type = it->second.s();
  }
  const auto& is_component_fn_it = ndef.attr().find("is_component_function");
  if (is_component_fn_it != ndef.attr().end()) {
    options.is_component_function = is_component_fn_it->second.b();
  }
#if !defined(IS_MOBILE_PLATFORM)
  // The Android TF library does not include Grappler.
  const auto& config_it = ndef.attr().find("config_proto");
  if (config_it != ndef.attr().end()) {
    if (!options.config_proto.ParseFromString(config_it->second.s())) {
      return errors::InvalidArgument(
          "Failed to parse config_proto attribute as tensorflow::ConfigProto "
          "proto.");
    }
    grappler::GrapplerItem::OptimizationOptions optimization_options =
        grappler::CreateOptOptionsForEager();

    options.optimize_graph_fn = std::bind(
        grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2,
        std::placeholders::_3, std::placeholders::_4, std::placeholders::_5,
        options.config_proto, function_def->signature().name(),
        optimization_options, std::placeholders::_6);
  }
#endif  // !IS_MOBILE_PLATFORM
  options.graph_collector = graph_collector;

  // In eager mode we always inline all functions into the top-level function
  // body graph in order to get a single executable graph that can be
  // optimized across function boundaries (e.g. to prune unused inputs and
  // outputs in a function call chain). This is required to mimic graph-mode
  // execution, with its aggressive pruning of nodes not in the transitive
  // fanin of fetches.
  options.config_proto.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(true);

  options.config_proto.set_log_device_placement(log_device_placement);

  TF_RETURN_IF_ERROR(
      pflr_->Instantiate(ndef.op(), AttrSlice(ndef), options, &handle_));
  return pflr_->IsCrossProcess(handle_, &is_cross_process_);
}

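// Instantiates the function and caches the devices on which its outputs will
// be produced.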
Status KernelAndDeviceFunc::Init(const bool log_device_placement,
                                 const NodeDef& ndef,
                                 GraphCollector* graph_collector) {
  TF_RETURN_IF_ERROR(
      InstantiateFunc(log_device_placement, ndef, graph_collector));
  return pflr_->GetOutputDevices(handle_, &output_devices_);
}

namespace {
// In certain contexts (e.g. TPU async executions), the CancellationManager is
// used to shut down the device in error scenarios (as opposed to using the
// AsyncCompute's DoneCallback). This is handled through the
// {inc,dec}_num_deferred_ops_function.
struct OpExecutionState : public core::RefCounted {
  // TODO(nareshmodi): consider refcounting the cancellation_manager.
  CancellationManager cancellation_manager;
};
}  // anonymous namespace

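// Runs the kernel synchronously on `device_`: sets up OpKernelContext::Params
// (inputs, allocator attributes, cancellation, rendezvous, runner), invokes
// Device::Compute, and copies the produced tensors into `outputs`.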
Status KernelAndDeviceOp::Run(
    ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
    std::vector<EagerKernelRet>* outputs,
    CancellationManager* cancellation_manager,
    const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
    const absl::optional<ManagedStackTrace>& stack_trace) {
  OpKernelContext::Params params;
  params.device = device_;
  params.frame_iter = FrameAndIter(0, 0);
  params.inputs = inputs.GetTensorValues();
  params.op_kernel = kernel_.get();
  params.resource_manager = device_->resource_manager();
  params.input_alloc_attrs = &input_alloc_attrs_;
  params.output_attr_array = output_alloc_attrs_.data();
  params.function_library = flr_;
  params.slice_reader_cache = &slice_reader_cache_;
  params.rendezvous = rendezvous_;
  params.stack_trace = stack_trace;
  OpExecutionState* op_execution_state = nullptr;

  CancellationManager default_cancellation_manager;
  if (cancellation_manager) {
    params.cancellation_manager = cancellation_manager;
  } else if (kernel_->is_deferred()) {
    op_execution_state = new OpExecutionState;
    params.cancellation_manager = &op_execution_state->cancellation_manager;
    params.inc_num_deferred_ops_function = [op_execution_state]() {
      op_execution_state->Ref();
    };
    params.dec_num_deferred_ops_function = [op_execution_state]() {
      op_execution_state->Unref();
    };
  } else {
    params.cancellation_manager = &default_cancellation_manager;
  }

  params.log_memory = log_memory_;

  params.runner = get_runner();

  params.step_container = step_container;

  params.collective_executor =
      collective_executor_ ? collective_executor_->get() : nullptr;

  OpKernelContext context(&params);

  {
    port::ScopedFlushDenormal flush;
    port::ScopedSetRound round(FE_TONEAREST);
    // 'AnnotatedTraceMe' traces both the scheduling time on the host and the
    // execution time on the device of the OpKernel.
    profiler::AnnotatedTraceMe activity(
        [&] { return kernel_->TraceString(context, /*verbose=*/false); },
        profiler::TraceMeLevel::kInfo);
    device_->Compute(kernel_.get(), &context);
  }

  // Clean up op_execution_state if deferred ops aren't running.
  if (op_execution_state != nullptr) {
    op_execution_state->Unref();
  }

  if (!context.status().ok()) return context.status();

  if (outputs != nullptr) {
    outputs->clear();
    for (int i = 0; i < context.num_outputs(); ++i) {
      const auto* output_tensor = context.mutable_output(i);
      if (output_tensor != nullptr) {
        outputs->push_back(Tensor(*output_tensor));
      } else {
        outputs->push_back(Tensor());
      }
    }
  }
  return Status::OK();
}

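// Synchronous wrapper around RunAsync: blocks on a Notification until the
// asynchronous run completes and returns its status.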
Status KernelAndDeviceFunc::Run(
    ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
    std::vector<EagerKernelRet>* outputs,
    CancellationManager* cancellation_manager,
    const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
    const absl::optional<ManagedStackTrace>& stack_trace) {
  Notification n;
  Status status;
  RunAsync(step_container, inputs, outputs, cancellation_manager,
           remote_func_params, [&status, &n](const Status& s) {
             status = s;
             n.Notify();
           });
  n.WaitForNotification();
  return status;
}

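// Runs the instantiated multi-device function asynchronously through the
// ProcessFunctionLibraryRuntime: picks step and op ids, creates a per-step
// rendezvous and (if needed) a local CancellationManager, then invokes `done`
// with the final status.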
void KernelAndDeviceFunc::RunAsync(
    ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
    std::vector<EagerKernelRet>* outputs,
    CancellationManager* cancellation_manager,
    const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
    std::function<void(const Status&)> done) {
  std::shared_ptr<FunctionLibraryRuntime::Options> opts = nullptr;
  if (remote_func_params.has_value()) {
    const EagerRemoteFunctionParams& params = remote_func_params.value();
    if (params.step_id.has_value()) {
      // If the function is a remote component of a cross-process function,
      // reuse its parent function's step id.
      opts = std::make_shared<FunctionLibraryRuntime::Options>(
          params.step_id.value());
    } else {
      opts = std::make_shared<FunctionLibraryRuntime::Options>();
    }
    // Reuse the op id if it exists.
    opts->op_id = params.op_id;
  } else {
    opts = std::make_shared<FunctionLibraryRuntime::Options>();
    if (get_op_id_ && is_cross_process_) {
      // If the function is a cross-process function and the remote execution
      // goes through the eager service, create an eager op id for it.
      opts->op_id = get_op_id_();
    }
  }

  // We don't pass the rendezvous from the eager context because we can get
  // tensor name collisions in send/recv ops when running multiple instances
  // of the same multi-device function concurrently.
  Rendezvous* rendezvous = rendezvous_creator_(opts->step_id);
  opts->rendezvous = rendezvous;
  opts->create_rendezvous = false;

  // Create a cancellation manager to be used by the FLR options if the caller
  // does not pass one in. If the caller does provide one, it is passed to the
  // process FLR and the locally created one goes unused.
  std::shared_ptr<CancellationManager> local_cm;
  if (cancellation_manager) {
    opts->cancellation_manager = cancellation_manager;
  } else {
    local_cm = std::make_shared<CancellationManager>();
    opts->cancellation_manager = local_cm.get();
  }
  opts->allow_dead_tensors = true;
  opts->step_container = step_container;
  opts->collective_executor =
      collective_executor_ ? collective_executor_->get() : nullptr;

  opts->stats_collector = nullptr;
  opts->runner = get_runner();

  outputs->clear();

  pflr_->Run(*opts, handle_, inputs, outputs,
             [opts, rendezvous, local_cm, step_container, this,
              done = std::move(done)](const Status& s) {
               rendezvous->Unref();
               done(s);
             });
}

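// Device accessors: report where each input is expected and where each output
// or resource output will be placed (nullptr when no device applies, e.g.
// host-memory outputs, or non-resource outputs queried for a resource device).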
tensorflow::Device* KernelAndDeviceOp::OutputDevice(int idx) const {
  if (kernel_->output_memory_types()[idx] == HOST_MEMORY) {
    return nullptr;
  }
  return device_;
}

tensorflow::Device* KernelAndDeviceFunc::OutputDevice(int idx) const {
  if (output_dtypes_[idx] == DT_RESOURCE) {
    return nullptr;
  }
  return output_devices_[idx];
}

tensorflow::Device* KernelAndDeviceOp::OutputResourceDevice(int idx) const {
  if (kernel_->output_type(idx) == DT_RESOURCE) {
    return device_;
  }
  return nullptr;
}

tensorflow::Device* KernelAndDeviceFunc::OutputResourceDevice(int idx) const {
  if (output_dtypes_[idx] == DT_RESOURCE) {
    return output_devices_[idx];
  }
  return nullptr;
}

Device* KernelAndDeviceOp::InputDevice(int i) const {
  return input_devices_[i];
}

Device* KernelAndDeviceFunc::InputDevice(int i) const {
  if ((input_dtypes_[i] == DT_RESOURCE) &&
      (composite_devices_.find(input_devices_[i]->name()) ==
       composite_devices_.end())) {
    return host_cpu_device_;
  } else {
    return input_devices_[i];
  }
}

}  // namespace tensorflow