1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/executor.h"
17
18 #include <atomic>
19 #include <memory>
20 #include <vector>
21
22 #include "absl/memory/memory.h"
23 #include "tensorflow/core/common_runtime/costmodel_manager.h"
24 #include "tensorflow/core/common_runtime/entry.h"
25 #include "tensorflow/core/common_runtime/executor_factory.h"
26 #include "tensorflow/core/common_runtime/graph_view.h"
27 #include "tensorflow/core/common_runtime/immutable_executor_state.h"
28 #include "tensorflow/core/common_runtime/pending_counts.h"
29 #include "tensorflow/core/common_runtime/propagator_state.h"
30 #include "tensorflow/core/common_runtime/renamed_device.h"
31 #include "tensorflow/core/common_runtime/simple_propagator_state.h"
32 #include "tensorflow/core/common_runtime/step_stats_collector.h"
33 #include "tensorflow/core/framework/allocator.h"
34 #include "tensorflow/core/framework/cancellation.h"
35 #include "tensorflow/core/framework/collective.h"
36 #include "tensorflow/core/framework/control_flow.h"
37 #include "tensorflow/core/framework/device_attributes.pb.h"
38 #include "tensorflow/core/framework/log_memory.h"
39 #include "tensorflow/core/framework/metrics.h"
40 #include "tensorflow/core/framework/node_def_util.h"
41 #include "tensorflow/core/framework/op_kernel.h"
42 #include "tensorflow/core/framework/op_segment.h"
43 #include "tensorflow/core/framework/tensor.h"
44 #include "tensorflow/core/framework/tensor_reference.h"
45 #include "tensorflow/core/framework/types.h"
46 #include "tensorflow/core/framework/types.pb.h"
47 #include "tensorflow/core/graph/edgeset.h"
48 #include "tensorflow/core/graph/graph.h"
49 #include "tensorflow/core/graph/graph_node_util.h"
50 #include "tensorflow/core/lib/core/errors.h"
51 #include "tensorflow/core/lib/core/notification.h"
52 #include "tensorflow/core/lib/core/status.h"
53 #include "tensorflow/core/lib/core/threadpool.h"
54 #include "tensorflow/core/lib/gtl/flatmap.h"
55 #include "tensorflow/core/lib/gtl/inlined_vector.h"
56 #include "tensorflow/core/lib/gtl/manual_constructor.h"
57 #include "tensorflow/core/lib/hash/hash.h"
58 #include "tensorflow/core/platform/context.h"
59 #include "tensorflow/core/platform/env.h"
60 #include "tensorflow/core/platform/errors.h"
61 #include "tensorflow/core/platform/logging.h"
62 #include "tensorflow/core/platform/macros.h"
63 #include "tensorflow/core/platform/mutex.h"
64 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
65 #include "tensorflow/core/platform/thread_annotations.h"
66 #include "tensorflow/core/platform/tracing.h"
67 #include "tensorflow/core/platform/types.h"
68 #include "tensorflow/core/profiler/lib/annotated_traceme.h"
69 #include "tensorflow/core/profiler/lib/connected_traceme.h"
70 #include "tensorflow/core/profiler/lib/scoped_annotation.h"
71 #include "tensorflow/core/profiler/lib/traceme_encode.h"
72 #include "tensorflow/core/protobuf/error_codes.pb.h"
73 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
74
75 namespace tensorflow {
76
77 namespace {
78
79 // 1-D, 0 element tensor.
80 static const Tensor* const kEmptyTensor = new Tensor;
81
82 // Helper routines for collecting step stats.
83 namespace nodestats {
NowInNsec()84 inline int64 NowInNsec() { return EnvTime::NowNanos(); }
85
SetScheduled(NodeExecStatsInterface * stats,int64 micros)86 void SetScheduled(NodeExecStatsInterface* stats, int64 micros) {
87 if (!stats) return;
88 stats->SetScheduled(micros * EnvTime::kMicrosToNanos);
89 }
90
SetAllStart(NodeExecStatsInterface * stats)91 void SetAllStart(NodeExecStatsInterface* stats) {
92 if (!stats) return;
93 stats->RecordExecutorStarted();
94 }
95
SetOpStart(NodeExecStatsInterface * stats)96 void SetOpStart(NodeExecStatsInterface* stats) {
97 if (!stats) return;
98 stats->RecordComputeStarted();
99 }
100
SetOpEnd(NodeExecStatsInterface * stats)101 void SetOpEnd(NodeExecStatsInterface* stats) {
102 if (!stats) return;
103 stats->RecordComputeEnded();
104 }
105
SetAllEnd(NodeExecStatsInterface * stats)106 void SetAllEnd(NodeExecStatsInterface* stats) {
107 if (!stats) return;
108 stats->RecordExecutorEnded();
109 }
110
SetOutput(NodeExecStatsInterface * stats,int slot,const Tensor * v)111 void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) {
112 if (!stats) return;
113 stats->SetOutput(slot, v);
114 }
115
SetMemory(NodeExecStatsInterface * stats,OpKernelContext * ctx)116 void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
117 if (!stats) return;
118 stats->SetMemory(ctx);
119 }
120
121 } // namespace nodestats
122
123 // Time the execution of kernels (in CPU cycles). Used to dynamically identify
124 // inexpensive kernels which can be dispatched inline.
125 struct KernelTimer {
126 uint64 start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle();
127
ElapsedCyclestensorflow::__anon6f8fc96b0111::KernelTimer128 uint64 ElapsedCycles() {
129 return profile_utils::CpuUtils::GetCurrentClockCycle() - start_cycles;
130 }
131 };
132
133 // TODO(b/152925936): Re-evaluate these constants with current usage patterns.
134 typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
135 typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
136
137 class ExecutorImpl : public Executor {
138 public:
ExecutorImpl(const LocalExecutorParams & p)139 explicit ExecutorImpl(const LocalExecutorParams& p) : immutable_state_(p) {}
140
Initialize(const Graph & graph)141 Status Initialize(const Graph& graph) {
142 TF_RETURN_IF_ERROR(immutable_state_.Initialize(graph));
143 kernel_stats_.Initialize(immutable_state_.graph_view());
144 return Status::OK();
145 }
146
147 void RunAsync(const Args& args, DoneCallback done) override;
148
149 private:
150 template <class PropagatorStateType>
151 friend class ExecutorState;
152
153 // Stores execution time information about the kernels in an executor's graph.
154 class KernelStats {
155 public:
156 KernelStats() = default;
157
Initialize(const GraphView & gview)158 void Initialize(const GraphView& gview) {
159 is_expensive_.resize(gview.num_nodes());
160 cost_estimates_ =
161 absl::make_unique<std::atomic_uint_fast64_t[]>(gview.num_nodes());
162 for (int32 i = 0; i < gview.num_nodes(); ++i) {
163 if (gview.node(i)) {
164 is_expensive_[i] =
165 gview.node(i)->kernel && gview.node(i)->kernel->IsExpensive();
166 cost_estimates_[i] = kInitialCostEstimateCycles;
167 }
168 }
169 }
170
171 // Returns true iff the given node is considered "expensive". The
172 // executor uses this flag to optimize graph execution, for example
173 // by "inlining" inexpensive kernels.
IsExpensive(const NodeItem & node) const174 bool IsExpensive(const NodeItem& node) const {
175 return is_expensive_[node.node_id] &&
176 (cost_estimates_[node.node_id].load(std::memory_order_relaxed) >
177 kOpIsExpensiveThresholdCycles);
178 }
179
180 // Returns the value of kernel->IsExpensive().
HasExpensiveMarker(const NodeItem & node) const181 bool HasExpensiveMarker(const NodeItem& node) const {
182 return is_expensive_[node.node_id];
183 }
184
185 // Updates the dynamic cost estimate, which is used to determine whether the
186 // given node is expensive. The new cost estimate is a weighted average of
187 // the old cost estimate and the latest cost. We only update cost estimates
188 // for kernels for which IsExpensive() return true.
UpdateCostEstimate(const NodeItem & node,uint64 elapsed_cycles)189 void UpdateCostEstimate(const NodeItem& node, uint64 elapsed_cycles) {
190 // N.B. Updates to `cost_estimate` are atomic but unlocked. Simultaneous
191 // updates may result in one or more updates being ignored. This does not
192 // affect correctness but may slow down the update frequency.
193 std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node_id];
194 auto prev_estimate = cost_estimate.load(std::memory_order_relaxed);
195
196 uint64 new_estimate =
197 ((kCostDecay - 1) * prev_estimate + elapsed_cycles) / kCostDecay;
198
199 cost_estimate.store(new_estimate, std::memory_order_relaxed);
200 }
201
202 private:
203 // Initial time (in CPU cycles) we expect an operation to take. Used to
204 // determine whether an operation should be place in a threadpool.
205 // Operations start out "expensive".
206 static constexpr uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
207 static constexpr uint64 kOpIsExpensiveThresholdCycles = 8000;
208 static constexpr uint64 kCostDecay = 10;
209
210 std::vector<bool> is_expensive_;
211 // std::unique_ptr<std::atomic<bool>[]> is_expensive_;
212 std::unique_ptr<std::atomic_uint_fast64_t[]> cost_estimates_;
213 };
214
215 ImmutableExecutorState immutable_state_;
216 KernelStats kernel_stats_;
217
218 TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
219 };
220
// The state associated with one invocation of ExecutorImpl::RunAsync.
//
// ExecutorState dispatches nodes when they become ready, and delegates to an
// instance of `PropagatorStateType` to keep track of how many predecessors of a
// node are still pending.
//
// The template argument `class PropagatorStateType` must define the following
// public members:
// * A type `TaggedNode`, representing a node to be processed, with public
//   members:
//   * `const NodeItem& get_node_item() const`
//   * `bool get_is_dead() const`
//   * `get_iter_num() const`, whose result is included in the trace event
//     emitted by `Process()`
// * A type `TaggedNodeReadyQueue`, representing a queue of nodes to be
//   processed, with public members (having the same meanings as in an
//   `std::deque<TaggedNode>`):
//   * `void push_back(const TaggedNode& node)`
//   * `TaggedNode front() const`
//   * `void pop_front()`
//   * `bool empty() const`
// * A type `TaggedNodeSeq`, representing a list of nodes to be scheduled, with
//   public members (having the same meanings as in an
//   `std::vector<TaggedNode>`):
//   * `size_t size() const`
//   * `bool empty() const`
//   * `void clear()`
//   * `void reserve(size_t n)`
//   * `const_iterator begin() const`
//   * `const_iterator end() const`
// * A public constructor that can be called as
//   `PropagatorStateType(const ImmutableExecutorState& immutable_state,
//   int64 step_id, bool vlog)`.
// * The following public methods:
//   * `void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
//     TaggedNodeSeq* ready)`, which creates `TaggedNode` instances for the
//     nodes in `roots` and adds them to `*ready`
//   * `void PropagateOutputs(const TaggedNode& tagged_node, EntryVector*
//     outputs, TaggedNodeSeq* ready)`, which propagates `outputs` from the
//     given `tagged_node` to the destinations of its output edges, and adds
//     any newly runnable nodes to `*ready`
//   * `Entry* GetInputTensors(const TaggedNode& tagged_node) const`, which
//     returns a pointer to the input tensors for the given `tagged_node`
//   * `FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const`,
//     which creates a `FrameAndIter` for the given `tagged_node`
//   * `void DumpState()`, which dumps the dynamic state of the executing graph
//   * `void MaybeMarkStarted(const TaggedNode& tagged_node)`, which records
//     that a node has started
//   * `void MaybeMarkCompleted(const TaggedNode& tagged_node)`, which records
//     that a node has completed
//
// See `PropagatorState` in "./propagator_state.h" for an example of a type that
// can be used to instantiate `PropagatorStateType`.
template <class PropagatorStateType>
class ExecutorState {
 public:
  // Captures the per-step arguments in `args` and keeps references to the
  // executor-wide immutable state and kernel statistics, neither of which is
  // owned by this object.
  ExecutorState(const Executor::Args& args,
                const ImmutableExecutorState& immutable_state_,
                ExecutorImpl::KernelStats* kernel_stats_);
  ~ExecutorState();

  // Starts executing the graph from its root nodes. If the device context
  // cannot be obtained, or there are no root nodes, this object deletes itself
  // and `done` is invoked immediately; otherwise `done` is stored to be
  // invoked when the execution finishes.
  void RunAsync(Executor::DoneCallback done);

 private:
  // Use `TaggedNode` types defined by `PropagatorStateType`.
  typedef typename PropagatorStateType::TaggedNode TaggedNode;
  typedef
      typename PropagatorStateType::TaggedNodeReadyQueue TaggedNodeReadyQueue;
  typedef typename PropagatorStateType::TaggedNodeSeq TaggedNodeSeq;

  // Heap-allocated state that outlives the call that launches an asynchronous
  // kernel.
  struct AsyncState;

  // Process a ready node in current thread.
  void Process(TaggedNode node, int64 scheduled_nsec);

  // Runs a synchronous kernel for `item` and processes its outputs into
  // `*outputs`, growing `*outputs` to `item.num_outputs` entries if needed.
  Status ProcessSync(const NodeItem& item, OpKernelContext::Params* params,
                     EntryVector* outputs, NodeExecStatsInterface* stats);
  // Launches an asynchronous kernel for `item`. The kernel's outputs are
  // processed and propagated from its completion callback.
  void ProcessAsync(const NodeItem& item, const OpKernelContext::Params& params,
                    const TaggedNode& tagged_node, Entry* first_input,
                    NodeExecStatsInterface* stats);
  // Records start/end statistics for a no-op node without running a kernel.
  void ProcessNoop(NodeExecStatsInterface* stats);
  // Emits `item.const_tensor` as the node's single output without running a
  // kernel.
  void ProcessConstTensor(const NodeItem& item, EntryVector* outputs,
                          NodeExecStatsInterface* stats);

  // Before invoking item->kernel, fills in its "inputs".
  Status PrepareInputs(const NodeItem& item, Entry* first_input,
                       TensorValueVec* inputs,
                       AllocatorAttributeVec* input_alloc_attrs,
                       bool* is_input_dead);

  // After item->kernel computation is done, processes its outputs.
  Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
                        Entry* outputs, NodeExecStatsInterface* stats);

  // Called after each node finishes. Takes ownership of "stats". Returns true
  // if execution has completed.
  //
  // This method will clear `*ready` before returning.
  bool NodeDone(const Status& s, TaggedNodeSeq* ready,
                NodeExecStatsInterface* stats,
                TaggedNodeReadyQueue* inline_ready);

  // Schedule all the expensive nodes in '*ready', and put all the inexpensive
  // nodes in 'ready' into 'inline_ready'.
  //
  // This method will clear `*ready` before returning.
  //
  // REQUIRES: `!ready->empty()`.
  void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready);

  // A wrapper for runner_ to keep track of the pending queue length. Op
  // execution should dispatch work using this function instead of using runner_
  // directly.
  template <typename Closure>
  void RunTask(Closure&& c);

  // Clean up when this executor is done.
  void Finish();
  void ScheduleFinish();

  // Contains the device context assigned by the device at the beginning of a
  // step. Filled in by RunAsync() and unreferenced in the destructor.
  DeviceContext* device_context_ = nullptr;

  const bool vlog_;  // true if VLOG_IS_ON(1). Used to check vlog cheaply.

  // true if LogMemory::IsEnabled(). Used to check memory enabled cheaply.
  const bool log_memory_;

  int64 step_id_;
  // Not owned.
  RendezvousInterface* rendezvous_;
  CollectiveExecutor* collective_executor_ = nullptr;
  SessionState* session_state_;
  string session_handle_;
  const SessionMetadata* session_metadata_ = nullptr;
  TensorStore* tensor_store_;
  // Step-local container.
  ScopedStepContainer* step_container_;
  StepStatsCollectorInterface* const stats_collector_;
  const tracing::EventCollector* const event_collector_;
  Context context_;

  // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
  // instead of a pointer?  (avoids having to delete).
  // Allocated in the constructor and deleted in the destructor.
  checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
  CallFrameInterface* call_frame_;
  const ImmutableExecutorState& immutable_state_;
  ExecutorImpl::KernelStats* const kernel_stats_;
  CancellationManager* cancellation_manager_;
  // If not null, use this device to schedule intra-op operation
  std::unique_ptr<DeviceBase> user_device_;
  Executor::Args::Runner runner_;
  bool sync_on_finish_;
  const bool run_all_kernels_inline_;

  // Tracks pending predecessors and decides which nodes become ready. It is
  // constructed from `immutable_state_`, `step_id_` and `vlog_`, so it must be
  // declared after them.
  PropagatorStateType propagator_;

  // Invoked when the execution finishes.
  Executor::DoneCallback done_cb_;

  // Initialized to zero, then set by RunAsync() to the number of root nodes
  // that were made ready.
  std::atomic_int_fast32_t num_outstanding_ops_;

  // The deferred-op counter below is available via OpKernelContext to every
  // OpKernel invocation: Process() installs increment/decrement callbacks for
  // it in the kernel params.
  mutex num_deferred_ops_mu_;
  int64 num_deferred_ops_ TF_GUARDED_BY(num_deferred_ops_mu_) = 0;
  // When set, the decrement callback that brings `num_deferred_ops_` to zero
  // invokes Finish().
  bool finish_when_deferred_ops_done_ TF_GUARDED_BY(num_deferred_ops_mu_) =
      false;

  mutex mu_;
  Status status_ TF_GUARDED_BY(mu_);
};
389
390 template <class PropagatorStateType>
ExecutorState(const Executor::Args & args,const ImmutableExecutorState & immutable_state,ExecutorImpl::KernelStats * kernel_stats)391 ExecutorState<PropagatorStateType>::ExecutorState(
392 const Executor::Args& args, const ImmutableExecutorState& immutable_state,
393 ExecutorImpl::KernelStats* kernel_stats)
394 : vlog_(VLOG_IS_ON(1)),
395 log_memory_(LogMemory::IsEnabled()),
396 step_id_(args.step_id),
397 rendezvous_(args.rendezvous),
398 collective_executor_(args.collective_executor),
399 session_state_(args.session_state),
400 session_handle_(args.session_handle),
401 session_metadata_(immutable_state.params().session_metadata),
402 tensor_store_(args.tensor_store),
403 step_container_(args.step_container),
404 stats_collector_(args.stats_collector),
405 event_collector_(
406 tracing::GetEventCollector(tracing::EventCategory::kCompute)),
407 context_(ContextKind::kThread),
408 slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
409 call_frame_(args.call_frame),
410 immutable_state_(immutable_state),
411 kernel_stats_(kernel_stats),
412 cancellation_manager_(args.cancellation_manager),
413 runner_(args.runner),
414 sync_on_finish_(args.sync_on_finish),
415 run_all_kernels_inline_(args.run_all_kernels_inline),
416 propagator_(immutable_state, step_id_, vlog_),
417 num_outstanding_ops_(0) {
418 if (args.user_intra_op_threadpool != nullptr) {
419 Device* device = immutable_state_.params().device;
420 user_device_ = RenamedDevice::NewRenamedDevice(
421 device->name(), device, false, false, args.user_intra_op_threadpool);
422 }
423 }
424
425 template <class PropagatorStateType>
~ExecutorState()426 ExecutorState<PropagatorStateType>::~ExecutorState() {
427 if (device_context_) {
428 device_context_->Unref();
429 }
430 delete slice_reader_cache_;
431 }
432
433 template <class PropagatorStateType>
434 template <typename Closure>
RunTask(Closure && c)435 void ExecutorState<PropagatorStateType>::RunTask(Closure&& c) {
436 // Align the atomic variables at 64 bytes to avoid false-sharing, assuming the
437 // cacheline size is 64 bytes or smaller.
438 alignas(64) static std::atomic<int64_t> num_enqueue_ops{0};
439 alignas(64) static std::atomic<int64_t> num_dequeue_ops{0};
440
441 auto n_enqueues = num_enqueue_ops.fetch_add(1, std::memory_order_relaxed);
442 // Sample the queue length on every 16 enqueue operations. This amortizes the
443 // cost of metric updates across 16 operations.
444 if (n_enqueues % 16 == 0) {
445 auto n_dequeues = num_dequeue_ops.load(std::memory_order_relaxed);
446 metrics::UpdateGraphPendingQueueLength(n_enqueues - n_dequeues);
447 }
448
449 // mutable is needed because std::forward<Closure> in the lambda body may move
450 // the Closure `c`.
451 runner_([c = std::forward<Closure>(c)]() mutable {
452 num_dequeue_ops.fetch_add(1, std::memory_order_relaxed);
453 std::forward<Closure>(c)();
454 });
455 }
456
457 template <class PropagatorStateType>
RunAsync(Executor::DoneCallback done)458 void ExecutorState<PropagatorStateType>::RunAsync(Executor::DoneCallback done) {
459 TaggedNodeSeq ready;
460
461 // Ask the device to fill in the device context map.
462 Device* device = immutable_state_.params().device;
463 const Status get_context_status =
464 device->TryGetDeviceContext(&device_context_);
465 if (!get_context_status.ok()) {
466 delete this;
467 done(get_context_status);
468 return;
469 }
470
471 // Initialize the ready queue.
472 ready.reserve(immutable_state_.root_nodes().size());
473 propagator_.ActivateRoots(immutable_state_.root_nodes(), &ready);
474 num_outstanding_ops_ = ready.size();
475 if (ready.empty()) {
476 delete this;
477 done(Status::OK());
478 } else {
479 done_cb_ = std::move(done);
480 // Schedule to run all the ready ops in thread pool.
481 ScheduleReady(&ready, nullptr);
482 }
483 }
484
485 // State kept alive for executing an asynchronous node in another
486 // thread. NOTE: We need to make a copy of p.input and p.input_alloc_attrs for
487 // asynchronous kernels because OpKernelContext methods like input_type(i) needs
488 // the param points to valid input type vector. It's not an issue for
489 // sync kernels because these vectors are kept on the stack.
490 template <class PropagatorStateType>
491 struct ExecutorState<PropagatorStateType>::AsyncState {
AsyncStatetensorflow::__anon6f8fc96b0111::ExecutorState::AsyncState492 AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
493 const NodeItem* _item, Entry* _first_input,
494 NodeExecStatsInterface* _stats)
495 : saved_inputs(*p.inputs),
496 saved_input_alloc_attrs(*p.input_alloc_attrs),
497 params(p),
498 tagged_node(_tagged_node),
499 item(_item),
500 first_input(_first_input),
501 // ParamsButClearingEigenGPUDevice does equivalent of
502 // params.eigen_gpu_device = nullptr;
503 ctx(ParamsButClearingEigenGPUDevice(¶ms), item->num_outputs),
504 stats(_stats) {
505 params.inputs = &saved_inputs;
506 params.input_alloc_attrs = &saved_input_alloc_attrs;
507 }
508
509 TensorValueVec saved_inputs;
510 AllocatorAttributeVec saved_input_alloc_attrs;
511 OpKernelContext::Params params;
512 TaggedNode tagged_node;
513 const NodeItem* item;
514 Entry* first_input;
515 OpKernelContext ctx;
516 NodeExecStatsInterface* stats;
517
518 private:
ParamsButClearingEigenGPUDevicetensorflow::__anon6f8fc96b0111::ExecutorState::AsyncState519 OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
520 OpKernelContext::Params* p) {
521 // Ensure OpKernelContext constructor will make a new eigen GPU device if
522 // necessary.
523 p->eigen_gpu_device = nullptr; // Force allocation
524 return p;
525 }
526 };
527
528 // Returns true if `item` might be traced by the given trace and event
529 // collectors. Returns false only if `item` definitely will not be traced.
MightTrace(const tracing::EventCollector * event_collector,bool is_expensive)530 bool MightTrace(const tracing::EventCollector* event_collector,
531 bool is_expensive) {
532 // Tracing will only be enabled if either `event_collector` is non null,
533 // or `trace_collector` is non-null and enabled for this particular kernel.
534 // Although `profiler::TraceMe`, `profiler::ScopedAnnotation`, and
535 // `tracing::ScopedRegion` check subsets of these properties internally in
536 // their constructors, the cost of passing the necessary arguments to them can
537 // be significant, so we avoid constructing them in the common case (when we
538 // know they will not be used).
539 if (event_collector != nullptr) {
540 return true;
541 }
542
543 if (profiler::ScopedAnnotation::IsEnabled()) return true;
544
545 return profiler::TraceMe::Active(profiler::GetTFTraceMeLevel(is_expensive));
546 }
547
548 template <class PropagatorStateType>
ProcessSync(const NodeItem & item,OpKernelContext::Params * params,EntryVector * outputs,NodeExecStatsInterface * stats)549 Status ExecutorState<PropagatorStateType>::ProcessSync(
550 const NodeItem& item, OpKernelContext::Params* params, EntryVector* outputs,
551 NodeExecStatsInterface* stats) {
552 Status s;
553 OpKernelContext ctx(params, item.num_outputs);
554 nodestats::SetOpStart(stats);
555
556 OpKernel* op_kernel = item.kernel;
557 Device* device = immutable_state_.params().device;
558 const bool is_expensive = kernel_stats_->IsExpensive(item);
559
560 if (TF_PREDICT_FALSE(MightTrace(event_collector_, is_expensive))) {
561 tracing::ScopedRegion region(tracing::EventCategory::kCompute,
562 op_kernel->name_view());
563 profiler::AnnotatedTraceMe activity(
564 [op_kernel, &ctx] {
565 return op_kernel->TraceString(
566 ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
567 },
568 profiler::GetTFTraceMeLevel(is_expensive));
569 device->Compute(op_kernel, &ctx);
570 } else if (kernel_stats_->HasExpensiveMarker(item)) {
571 KernelTimer timer;
572 device->Compute(op_kernel, &ctx);
573 // For expensive kernels, always update the cost estimate. For inexpensive
574 // kernels, update the cost estimate with ~1/16 probability. This assumes
575 // that the last 4 bits of the CPU cycle count is uniformly distributed.
576 constexpr int kKernelExecutionTrackingInvocationSkipCount = 16;
577 if (is_expensive ||
578 timer.start_cycles % kKernelExecutionTrackingInvocationSkipCount == 0) {
579 kernel_stats_->UpdateCostEstimate(item, timer.ElapsedCycles());
580 }
581 } else {
582 device->Compute(op_kernel, &ctx);
583 }
584 nodestats::SetOpEnd(stats);
585 if (outputs->size() < item.num_outputs) outputs->resize(item.num_outputs);
586 s = ProcessOutputs(item, &ctx, outputs->data(), stats);
587 nodestats::SetMemory(stats, &ctx);
588 return s;
589 }
590
// Launches the asynchronous kernel of `item`.
//
// The kernel context and copies of the inputs are stored in a heap-allocated
// AsyncState, which is deleted by the completion callback. All output
// processing and propagation for the node happens in that callback, which runs
// whenever the kernel signals completion.
template <class PropagatorStateType>
void ExecutorState<PropagatorStateType>::ProcessAsync(
    const NodeItem& item, const OpKernelContext::Params& params,
    const TaggedNode& tagged_node, Entry* first_input,
    NodeExecStatsInterface* stats) {
  AsyncOpKernel* async_kernel = item.kernel->AsAsync();
  DCHECK(async_kernel != nullptr);
  // Deleted at the end of the `done` callback below.
  AsyncState* state =
      new AsyncState(params, tagged_node, &item, first_input, stats);

  // Completion callback: finishes the bookkeeping for the node. Everything it
  // needs is read from `state`, not from this function's arguments.
  auto done = [this, state]() {
    Device* device = immutable_state_.params().device;
    NodeExecStatsInterface* stats = state->stats;  // Shorthand
    Entry* first_input = state->first_input;       // Shorthand

    nodestats::SetOpEnd(stats);
    EntryVector outputs(state->item->num_outputs);
    Status s = ProcessOutputs(*state->item, &state->ctx, outputs.data(), stats);
    nodestats::SetMemory(stats, &state->ctx);
    if (vlog_) {
      VLOG(2) << "Async kernel done: " << state->item->node_id << " step "
              << step_id_ << " " << SummarizeNodeDef(state->item->kernel->def())
              << (state->tagged_node.get_is_dead() ? " is dead" : "")
              << " device: " << device->name();
    }

    // Clears inputs.
    const int num_inputs = state->item->num_inputs;
    for (int i = 0; i < num_inputs; ++i) {
      (first_input + i)->ClearVal();
    }
    propagator_.MaybeMarkCompleted(state->tagged_node);
    TaggedNodeSeq ready;
    // Outputs are only propagated to successors when they were processed
    // successfully.
    if (s.ok()) {
      propagator_.PropagateOutputs(state->tagged_node, &outputs, &ready);
    }
    outputs.clear();
    // No inline queue is available in the callback, so null is passed for
    // `inline_ready`.
    const bool completed = NodeDone(s, &ready, stats, nullptr);
    // `state` is not used past this point; `completed` is a local copy.
    delete state;
    if (completed) ScheduleFinish();
  };
  nodestats::SetOpStart(stats);
  {
    // The trace activity covers only the call that launches the kernel.
    profiler::AnnotatedTraceMe activity(
        [async_kernel, state] {
          return async_kernel->TraceString(
              state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
        },
        profiler::GetTFTraceMeLevel(kernel_stats_->IsExpensive(item)));
    immutable_state_.params().device->ComputeAsync(async_kernel, &state->ctx,
                                                   std::move(done));
  }
}
644
645 template <class PropagatorStateType>
ProcessNoop(NodeExecStatsInterface * stats)646 void ExecutorState<PropagatorStateType>::ProcessNoop(
647 NodeExecStatsInterface* stats) {
648 nodestats::SetOpStart(stats);
649 nodestats::SetOpEnd(stats);
650 }
651
652 template <class PropagatorStateType>
ProcessConstTensor(const NodeItem & item,EntryVector * outputs,NodeExecStatsInterface * stats)653 void ExecutorState<PropagatorStateType>::ProcessConstTensor(
654 const NodeItem& item, EntryVector* outputs, NodeExecStatsInterface* stats) {
655 nodestats::SetOpStart(stats);
656 nodestats::SetOpEnd(stats);
657 Entry& output = (*outputs)[0];
658 output.state = Entry::State::HAS_CONST_TENSOR;
659 output.const_tensor = item.const_tensor;
660 output.alloc_attr = item.output_attrs()[0];
661 }
662
663 template <class PropagatorStateType>
Process(TaggedNode tagged_node,int64 scheduled_nsec)664 void ExecutorState<PropagatorStateType>::Process(TaggedNode tagged_node,
665 int64 scheduled_nsec) {
666 profiler::TraceMeConsumer activity(
667 // From TraceMeProducer in DirectSession::RunInternal,
668 // GraphMgr::ExecuteAsync, or FunctionLibraryRuntime::Run.
669 [&] {
670 // NOTE: This tracing uses the iteration number from the first tagged
671 // node that executes during this call to `Process()`. In principle,
672 // subsequent nodes could have different values of `iter_num` that
673 // will not be traced.
674 return profiler::TraceMeEncode(
675 "ExecutorState::Process",
676 {{"id", step_id_}, {"iter_num", tagged_node.get_iter_num()}});
677 },
678 profiler::ContextType::kTfExecutor, step_id_,
679 profiler::TraceMeLevel::kInfo);
680 WithContext wc(context_);
681 TaggedNodeSeq ready;
682 TaggedNodeReadyQueue inline_ready;
683
684 // Parameters passed to OpKernel::Compute.
685 TensorValueVec inputs;
686 AllocatorAttributeVec input_alloc_attrs;
687
688 OpKernelContext::Params params;
689 params.step_id = step_id_;
690 // Override device's threadpool if user provides an intra_op_threadpool
691 Device* device = immutable_state_.params().device;
692 if (user_device_) {
693 params.device = user_device_.get();
694 } else {
695 params.device = device;
696 }
697 params.log_memory = log_memory_;
698 params.rendezvous = rendezvous_;
699 params.collective_executor = collective_executor_;
700 params.session_state = session_state_;
701 params.session_handle = session_handle_;
702 params.session_metadata = session_metadata_;
703 params.tensor_store = tensor_store_;
704 params.cancellation_manager = cancellation_manager_;
705 params.call_frame = call_frame_;
706 params.function_library = immutable_state_.params().function_library;
707 params.resource_manager = device->resource_manager();
708 params.step_container = step_container_;
709 params.slice_reader_cache = slice_reader_cache_;
710 params.inputs = &inputs;
711 params.input_alloc_attrs = &input_alloc_attrs;
712 params.runner = &runner_;
713 params.run_all_kernels_inline = run_all_kernels_inline_;
714 params.stats_collector = stats_collector_;
715 params.inc_num_deferred_ops_function = [this]() {
716 mutex_lock lock(num_deferred_ops_mu_);
717 num_deferred_ops_++;
718 };
719 params.dec_num_deferred_ops_function = [this]() {
720 bool finish_when_deferred_ops_done = false;
721 {
722 mutex_lock lock(num_deferred_ops_mu_);
723 num_deferred_ops_--;
724 if (num_deferred_ops_ == 0) {
725 finish_when_deferred_ops_done = finish_when_deferred_ops_done_;
726 }
727 }
728 // Invoke Finish if the graph processing has completed. Finish is always
729 // called exactly once per ExecutorState, either here if there are any
730 // deferred ops, or in ScheduleFinish if there aren't any deferred ops.
731 if (finish_when_deferred_ops_done) Finish();
732 };
733
734 // Set the device_context for this device, if it exists.
735 params.op_device_context = device_context_;
736
737 Status s;
738 NodeExecStatsInterface* stats = nullptr;
739
740 EntryVector outputs(1);
741
742 bool completed = false;
743 inline_ready.push_back(tagged_node);
744 while (!inline_ready.empty()) {
745 tagged_node = inline_ready.front();
746 inline_ready.pop_front();
747 const NodeItem& item = tagged_node.get_node_item();
748 const int id = item.node_id;
749
750 propagator_.MaybeMarkStarted(tagged_node);
751
752 params.track_allocations = false;
753 stats = nullptr;
754 if (stats_collector_ && !tagged_node.get_is_dead()) {
755 stats = stats_collector_->CreateNodeExecStats(&item.kernel->def());
756 // Track allocations if and only if we are collecting statistics, and
757 // `stats` object is expecting allocations to be tracked.
758 params.track_allocations = stats ? stats->TrackAllocations() : false;
759 nodestats::SetScheduled(stats, scheduled_nsec);
760 nodestats::SetAllStart(stats);
761 }
762
763 if (vlog_) {
764 VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
765 << SummarizeNodeDef(item.kernel->def())
766 << (tagged_node.get_is_dead() ? " is dead" : "")
767 << " device: " << device->name();
768 }
769
770 Entry* first_input = propagator_.GetInputTensors(tagged_node);
771
772 // Only execute this node if it is not dead or it is a send/recv
773 // transfer node. For transfer nodes, we need to propagate the "dead"
774 // bit even when the node is dead.
775 bool launched_asynchronously = false;
776 if (tagged_node.get_is_dead() && !item.is_transfer_node) {
777 if (outputs.size() < item.num_outputs) outputs.resize(item.num_outputs);
778 } else if (TF_PREDICT_FALSE(item.is_noop)) {
779 ProcessNoop(stats);
780 } else if (item.const_tensor != nullptr && !params.track_allocations) {
781 ProcessConstTensor(item, &outputs, stats);
782 } else {
783 // Prepares inputs.
784 bool is_input_dead = false;
785 s = PrepareInputs(item, first_input, &inputs, &input_alloc_attrs,
786 &is_input_dead);
787 if (!s.ok()) {
788 // Clear inputs.
789 const int num_inputs = item.num_inputs;
790 for (int i = 0; i < num_inputs; ++i) {
791 (first_input + i)->ClearVal();
792 }
793 propagator_.MaybeMarkCompleted(tagged_node);
794 // Continue to process the nodes in 'inline_ready'.
795 completed = NodeDone(s, &ready, stats, &inline_ready);
796 continue;
797 }
798
799 // Set up compute params.
800 params.op_kernel = item.kernel;
801 params.frame_iter = propagator_.GetFrameAndIter(tagged_node);
802 params.is_input_dead = is_input_dead;
803 params.output_attr_array = item.output_attrs();
804 params.forward_from_array = item.forward_from();
805 params.outputs_required_array = item.outputs_required.get();
806
807 if (item.kernel_is_async) {
808 ProcessAsync(item, params, tagged_node, first_input, stats);
809 launched_asynchronously = true;
810 } else {
811 s = ProcessSync(item, ¶ms, &outputs, stats);
812 }
813 }
814
815 if (!launched_asynchronously) {
816 if (vlog_) {
817 VLOG(2) << "Synchronous kernel done: " << id << " step "
818 << params.step_id << " " << SummarizeNodeDef(item.kernel->def())
819 << (tagged_node.get_is_dead() ? " is dead: " : "")
820 << " device: " << device->name();
821 }
822
823 // Clears inputs.
824 const int num_inputs = item.num_inputs;
825 for (int i = 0; i < num_inputs; ++i) {
826 (first_input + i)->ClearVal();
827 }
828 propagator_.MaybeMarkCompleted(tagged_node);
829 // Propagates outputs.
830 if (s.ok()) {
831 propagator_.PropagateOutputs(tagged_node, &outputs, &ready);
832 }
833
834 // Clear outputs without deallocating the `outputs` vector.
835 const int num_outputs = item.num_outputs;
836 for (int i = 0; i < num_outputs; ++i) {
837 outputs[i].ClearVal();
838 }
839
840 if (stats) {
841 scheduled_nsec = nodestats::NowInNsec();
842 }
843 // Postprocess.
844 completed = NodeDone(s, &ready, stats, &inline_ready);
845 }
846 } // while !inline_ready.empty()
847
848 // This thread of computation is done if completed = true.
849 if (completed) ScheduleFinish();
850 }
851
852 template <class PropagatorStateType>
PrepareInputs(const NodeItem & item,Entry * first_input,TensorValueVec * inputs,AllocatorAttributeVec * input_alloc_attrs,bool * is_input_dead)853 Status ExecutorState<PropagatorStateType>::PrepareInputs(
854 const NodeItem& item, Entry* first_input, TensorValueVec* inputs,
855 AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) {
856 inputs->resize(item.num_inputs);
857 input_alloc_attrs->resize(item.num_inputs);
858
859 *is_input_dead = false;
860
861 for (int i = 0; i < item.num_inputs; ++i) {
862 const bool expect_ref = TF_PREDICT_FALSE(item.is_any_input_ref_typed) &&
863 IsRefType(item.input_type(i));
864 Entry* entry = first_input + i;
865 (*input_alloc_attrs)[i] = entry->alloc_attr;
866
867 // i-th input.
868 TensorValue* inp = &(*inputs)[i];
869
870 switch (entry->state) {
871 case Entry::State::NO_VALUE: {
872 // Only merge and transfer nodes can have no-value inputs.
873 inp->mutex_if_ref = nullptr;
874 if (item.is_merge) {
875 inp->tensor = nullptr;
876 } else {
877 DCHECK(item.is_transfer_node)
878 << item.kernel->name() << " - input " << i;
879 entry->state = Entry::State::HAS_CONST_TENSOR;
880 entry->const_tensor = kEmptyTensor;
881 // NOTE(mrry): This `const_cast` is necessary because `TensorValue`
882 // stores a non-const `Tensor*`, and relies on the `OpKernelContext`
883 // accessors making dynamic checks that prevent using an immutable
884 // tensor as a mutable tensor.
885 inp->tensor = const_cast<Tensor*>(kEmptyTensor);
886 *is_input_dead = true;
887 }
888 break;
889 }
890
891 case Entry::State::HAS_VALUE: {
892 if (TF_PREDICT_FALSE(expect_ref)) {
893 return AttachDef(
894 errors::InvalidArgument(i, "-th input expects a ref type"),
895 item.kernel->def());
896 }
897 inp->mutex_if_ref = nullptr;
898 inp->tensor = entry->val.get();
899 break;
900 }
901
902 case Entry::State::HAS_CONST_TENSOR: {
903 if (TF_PREDICT_FALSE(expect_ref)) {
904 return AttachDef(
905 errors::InvalidArgument(i, "-th input expects a ref type"),
906 item.kernel->def());
907 }
908 // NOTE(mrry): This `const_cast` is necessary because `TensorValue`
909 // stores a non-const `Tensor*`, and relies on the `OpKernelContext`
910 // accessors making dynamic checks that prevent using an immutable
911 // tensor as a mutable tensor.
912 inp->mutex_if_ref = nullptr;
913 inp->tensor = const_cast<Tensor*>(entry->const_tensor);
914 break;
915 }
916
917 case Entry::State::HAS_REF_TENSOR: {
918 {
919 tf_shared_lock ml(*entry->ref_tensor.mu);
920 if (TF_PREDICT_FALSE(!entry->ref_tensor.tensor->IsInitialized() &&
921 !item.is_initialization_op)) {
922 return AttachDef(errors::FailedPrecondition(
923 "Attempting to use uninitialized value ",
924 item.kernel->requested_input(i)),
925 item.kernel->def());
926 }
927 }
928
929 if (expect_ref) {
930 inp->mutex_if_ref = entry->ref_tensor.mu;
931 inp->tensor = entry->ref_tensor.tensor;
932 } else {
933 // Automatically deref the tensor ref when the op expects a
934 // tensor but is given a ref to a tensor. Need to deref it
935 // under the mutex.
936 {
937 mutex* ref_mu = entry->ref_tensor.mu;
938 Tensor* ref_tensor = entry->ref_tensor.tensor;
939 tf_shared_lock l(*ref_mu);
940 entry->val.Init(*ref_tensor);
941 }
942 entry->state = Entry::State::HAS_VALUE;
943
944 inp->mutex_if_ref = nullptr;
945 inp->tensor = entry->val.get();
946 // The dtype of entry->ref_tensor.tensor could have been changed by
947 // another operation that ran after the operation that "produced" it
948 // executed, so re-validate that the type of the dereferenced tensor
949 // matches the expected input type.
950 if (TF_PREDICT_FALSE(item.input_type(i) != inp->tensor->dtype())) {
951 return AttachDef(
952 errors::InvalidArgument(
953 i, "-th input expects type ",
954 DataTypeString(item.input_type(i)),
955 " but automatically dereferenced input tensor has type ",
956 DataTypeString(inp->tensor->dtype())),
957 item.kernel->def());
958 }
959 }
960 break;
961 }
962 }
963 }
964 return Status::OK();
965 }
966
967 template <class PropagatorStateType>
ProcessOutputs(const NodeItem & item,OpKernelContext * ctx,Entry * outputs,NodeExecStatsInterface * stats)968 Status ExecutorState<PropagatorStateType>::ProcessOutputs(
969 const NodeItem& item, OpKernelContext* ctx, Entry* outputs,
970 NodeExecStatsInterface* stats) {
971 Status s = ctx->status();
972 if (!s.ok()) {
973 s = AttachDef(s, item.kernel->def());
974 // TODO(misard) Replace with a finer-grain enabling flag once we
975 // add better optional debugging support.
976 if (vlog_ && VLOG_IS_ON(1)) {
977 LOG(WARNING) << this << " Compute status: " << s;
978 }
979 if (s.code() == error::RESOURCE_EXHAUSTED) {
980 if (stats_collector_) {
981 string err = stats_collector_->ReportAllocsOnResourceExhausted(
982 s.error_message());
983 s = Status(s.code(), strings::StrCat(s.error_message(), err));
984 } else {
985 s = Status(
986 s.code(),
987 strings::StrCat(
988 s.error_message(),
989 "\nHint: If you want to see a list of allocated tensors when "
990 "OOM happens, add report_tensor_allocations_upon_oom "
991 "to RunOptions for current allocation info.\n"));
992 }
993 }
994 return s;
995 }
996
997 for (int i = 0; i < item.num_outputs; ++i) {
998 const TensorValue val = ctx->release_output(i);
999 Entry* out = &outputs[i];
1000 DCHECK(out->state == Entry::State::NO_VALUE);
1001
1002 if (val.tensor == nullptr) {
1003 // Unless it's a Switch or a Recv, or the executor has marked the output
1004 // as not required, the node must produce a tensor value at i-th output.
1005 if (!(item.is_recv_or_switch ||
1006 (item.outputs_required && !item.outputs_required[i]))) {
1007 s.Update(errors::Internal("Missing ", i, "-th output from ",
1008 FormatNodeDefForError(item.kernel->def())));
1009 }
1010 } else {
1011 // Set the allocator attributes of the output entry.
1012 out->alloc_attr = ctx->output_alloc_attr(i);
1013
1014 // Sanity check of output tensor types. We need to inspect this safely as
1015 // we are in the tensor buffer.
1016 DataType dtype = val.dtype_safe();
1017 if (dtype == item.output_type(i)) {
1018 if (stats && val.tensor->IsInitialized()) {
1019 nodestats::SetOutput(stats, i, val.tensor);
1020 }
1021 if (val.is_ref()) {
1022 out->state = Entry::State::HAS_REF_TENSOR;
1023 out->ref_tensor.tensor = val.tensor;
1024 out->ref_tensor.mu = val.mutex_if_ref;
1025 if (log_memory_) {
1026 Tensor to_log;
1027 {
1028 // Dereference the tensor under the lock.
1029 tf_shared_lock l(*out->ref_tensor.mu);
1030 to_log = *out->ref_tensor.tensor;
1031 }
1032 LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
1033 ctx->step_id(), i, to_log);
1034 }
1035 } else {
1036 // NOTE that std::move is used here, so val.tensor goes to
1037 // uninitialized state (val.tensor->IsInitialized return false).
1038 out->state = Entry::State::HAS_VALUE;
1039 out->val.Init(std::move(*val.tensor));
1040 if (log_memory_) {
1041 LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
1042 ctx->step_id(), i, *out->val);
1043 }
1044 }
1045 } else {
1046 s.Update(
1047 errors::Internal("Output ", i, " of type ", DataTypeString(dtype),
1048 " does not match declared output type ",
1049 DataTypeString(item.output_type(i)), " for node ",
1050 FormatNodeDefForError(item.kernel->def())));
1051 }
1052 }
1053 if (!val.is_ref()) {
1054 // If OpKernelContext returns outputs via pass-by-value, we
1055 // don't need this trouble.
1056 delete val.tensor;
1057 }
1058 }
1059 return s;
1060 }
1061
1062 template <class PropagatorStateType>
NodeDone(const Status & s,TaggedNodeSeq * ready,NodeExecStatsInterface * stats,TaggedNodeReadyQueue * inline_ready)1063 bool ExecutorState<PropagatorStateType>::NodeDone(
1064 const Status& s, TaggedNodeSeq* ready, NodeExecStatsInterface* stats,
1065 TaggedNodeReadyQueue* inline_ready) {
1066 if (stats) {
1067 nodestats::SetAllEnd(stats);
1068 DCHECK_NE(stats_collector_, nullptr);
1069 stats->Done(immutable_state_.params().device->name());
1070 }
1071
1072 if (TF_PREDICT_TRUE(s.ok())) {
1073 const size_t ready_size = ready->size();
1074 if (ready_size == 0) {
1075 return num_outstanding_ops_.fetch_sub(1) == 1;
1076 } else {
1077 // NOTE: Avoid touching the atomic counter if only one node becomes ready.
1078 if (ready_size > 1) {
1079 num_outstanding_ops_.fetch_add(ready_size - 1,
1080 std::memory_order_relaxed);
1081 }
1082
1083 // Schedule the ready nodes in 'ready'.
1084 ScheduleReady(ready, inline_ready);
1085
1086 return false;
1087 }
1088 } else {
1089 bool abort_run = false;
1090
1091 // Some error happened. This thread of computation is done.
1092 {
1093 mutex_lock l(mu_);
1094 if (status_.ok()) {
1095 // If this is the first node to fail in this run, we are responsible for
1096 // aborting all other execution in the step.
1097 abort_run = true;
1098
1099 // If execution has been cancelled, mark cancelled or aborted errors as
1100 // being derived. Note that the original node that fails might also
1101 // trigger cancellation, and here we make sure the original error is
1102 // exposed to users and not buried as a derived error.
1103 if (cancellation_manager_ && cancellation_manager_->IsCancelled() &&
1104 (errors::IsCancelled(s) || errors::IsAborted(s))) {
1105 status_ = StatusGroup::MakeDerived(s);
1106 } else {
1107 status_ = s;
1108 }
1109 }
1110 }
1111
1112 if (abort_run) {
1113 TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
1114 if (cancellation_manager_) {
1115 // Only log when the abort happens during the actual run time.
1116 // Use VLOG instead of LOG(warning) because error status is expected
1117 // when the executor is run under the grappler optimization phase or
1118 // when iterating through a tf.data input pipeline.
1119 VLOG(1) << "[" << immutable_state_.params().device->name()
1120 << "] Executor start aborting: " << s;
1121 }
1122
1123 if (rendezvous_) {
1124 rendezvous_->StartAbort(s);
1125 }
1126 if (cancellation_manager_) {
1127 cancellation_manager_->StartCancel();
1128 } else if (collective_executor_) {
1129 // If there's cancellation_manager_, collective ops aborts
1130 // collective_executor_ upon cancellation; otherwise we need to abort
1131 // here.
1132 collective_executor_->StartAbort(s);
1133 }
1134 }
1135
1136 return num_outstanding_ops_.fetch_sub(1) == 1;
1137 }
1138 }
1139
1140 template <class PropagatorStateType>
ScheduleReady(TaggedNodeSeq * ready,TaggedNodeReadyQueue * inline_ready)1141 void ExecutorState<PropagatorStateType>::ScheduleReady(
1142 TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready) {
1143 DCHECK(!ready->empty());
1144
1145 int64 scheduled_nsec = 0;
1146 if (stats_collector_) {
1147 scheduled_nsec = nodestats::NowInNsec();
1148 }
1149
1150 if (run_all_kernels_inline_) {
1151 if (inline_ready == nullptr) {
1152 // Schedule all ready kernels from a single closure. This ensure that,
1153 // regardless of the `runner_` implementation, all kernels will run
1154 // sequentially on the same thread, and thread wakeup overhead and
1155 // executor mutex contention will be minimized.
1156 RunTask([this, ready = std::move(*ready), scheduled_nsec]() {
1157 for (auto& tagged_node : ready) {
1158 Process(tagged_node, scheduled_nsec);
1159 }
1160 });
1161 } else {
1162 for (auto& tagged_node : *ready) {
1163 inline_ready->push_back(tagged_node);
1164 }
1165 }
1166 } else {
1167 const TaggedNode* curr_expensive_node = nullptr;
1168 if (inline_ready == nullptr) {
1169 // Schedule to run all the ready ops in thread pool.
1170 for (auto& tagged_node : *ready) {
1171 RunTask([=]() { Process(tagged_node, scheduled_nsec); });
1172 }
1173 } else {
1174 for (auto& tagged_node : *ready) {
1175 const NodeItem& item = *tagged_node.node_item;
1176 if (tagged_node.get_is_dead() || !kernel_stats_->IsExpensive(item)) {
1177 // Inline this inexpensive node.
1178 inline_ready->push_back(tagged_node);
1179 } else {
1180 if (curr_expensive_node) {
1181 // Dispatch to another thread since there is plenty of work to
1182 // do for this thread.
1183 RunTask(std::bind(&ExecutorState::Process, this,
1184 *curr_expensive_node, scheduled_nsec));
1185 }
1186 curr_expensive_node = &tagged_node;
1187 }
1188 }
1189 }
1190 if (curr_expensive_node) {
1191 if (inline_ready->empty()) {
1192 inline_ready->push_back(*curr_expensive_node);
1193 } else {
1194 // There are inline nodes to run already. We dispatch this expensive
1195 // node to other thread.
1196 RunTask(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
1197 scheduled_nsec));
1198 }
1199 }
1200 }
1201 ready->clear();
1202 }
1203
1204 template <class PropagatorStateType>
ScheduleFinish()1205 void ExecutorState<PropagatorStateType>::ScheduleFinish() {
1206 // Checks condition to decide if needs to invoke Finish(). If there are
1207 // in-flight deffered ops, wait for `num_deferred_ops_` reaches 0 to invoke
1208 // Finish(). Otherwise, invoke Finish() directly.
1209 // Note that it is critical that the ScheduleFinish / Finish codepath does not
1210 // block, otherwise we might deadlock. See b/124523000 for details.
1211 {
1212 mutex_lock lock(num_deferred_ops_mu_);
1213 if (num_deferred_ops_ > 0) {
1214 finish_when_deferred_ops_done_ = true;
1215 return;
1216 }
1217 }
1218 // Finish is always called exactly once per ExecutorState, either here if
1219 // there aren't any deferred ops, or in the dec_num_deferred_ops_function if
1220 // there are deferred ops.
1221 Finish();
1222 }
1223
1224 template <class PropagatorStateType>
Finish()1225 void ExecutorState<PropagatorStateType>::Finish() {
1226 mu_.lock();
1227 auto status = status_;
1228 auto done_cb = std::move(done_cb_);
1229 auto runner = std::move(runner_);
1230 mu_.unlock();
1231 int64 step_id = step_id_;
1232 CHECK(done_cb != nullptr);
1233 Device* device = immutable_state_.params().device;
1234
1235 if (vlog_ && !status.ok() && VLOG_IS_ON(1)) {
1236 // Logs verbose information about the current state of active and pending
1237 // nodes in the propagator.
1238 propagator_.DumpState();
1239 }
1240
1241 // There are several potential race conditions below. To name a few:
1242 // 1. Even if the device's status is OK at the precise moment when
1243 // num_deferred_ops_ reaches 0, it could go bad before device->RefreshStatus()
1244 // is called below, caused by work enqueued onto the same device by other
1245 // concurrent ExecutorState objects.
1246 // 2. Some implementations of Device::RefreshStatus, such as
1247 // XlaDevice::RefreshStatus, may be inherently racy because it releases the
1248 // device mutex after a stream pointer is acquired and before the stream is
1249 // queried for status.
1250 // 3. It's the same for some implementations of Device::Sync, such as
1251 // XlaDevice::Sync.
1252 //
1253 // However, these race conditions are acceptable because a stream (and
1254 // therefore an XlaDevice) can only go from OK to not-OK, never the opposite,
1255 // which means we will at worst report errors when there isn't any, never the
1256 // opposite.
1257
1258 // An early exit for devices don't allow sync on completion. Ops that run on
1259 // these devices should have used num_deferred_ops correctly to ensure the
1260 // device has finished all relevant work at this point.
1261 if (!device->AllowsSyncOnCompletion()) {
1262 status.Update(device->RefreshStatus());
1263 if (!status.ok()) {
1264 // In device async execution mode, it's possible for device execution to
1265 // lag behind ExecutorState scheduling so much that this is the first
1266 // place a device execution error surfaces.
1267 // If so, all ExecutorState::NodeDone calls have already happened with OK
1268 // status. This is the last defense where StartCancel must be called to
1269 // abort all computation still running on any device.
1270 // TODO(b/124523000): Always call Finish in a separate thread, so even if
1271 // StartCancel blocks the current thread's execution, we won't encounter
1272 // deadlocks caused by inter-op thread exhaustion.
1273 if (rendezvous_) {
1274 rendezvous_->StartAbort(status);
1275 }
1276 if (cancellation_manager_) {
1277 cancellation_manager_->StartCancel();
1278 } else if (collective_executor_) {
1279 // If there's cancellation_manager_, collective ops aborts
1280 // collective_executor_ upon cancellation; otherwise we need to abort
1281 // here.
1282 collective_executor_->StartAbort(status);
1283 }
1284 }
1285 delete this;
1286 runner([step_id, status, done_cb = std::move(done_cb)]() {
1287 profiler::TraceMeConsumer activity(
1288 // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
1289 // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
1290 [&] {
1291 return profiler::TraceMeEncode("ExecutorDoneCallback",
1292 {{"id", step_id}});
1293 },
1294 profiler::ContextType::kTfExecutor, step_id,
1295 profiler::TraceMeLevel::kInfo);
1296 done_cb(status);
1297 });
1298 return;
1299 }
1300
1301 if (sync_on_finish_ && status.ok()) {
1302 // Block until the device has finished all queued operations. For
1303 // devices like GPUs that continue to execute Ops after their Compute
1304 // methods have completed, this ensures that control is not returned to
1305 // the user until the step (and its side-effects) has actually completed.
1306 device->Sync([this, step_id, runner = std::move(runner),
1307 done_cb = std::move(done_cb)](const Status& status) mutable {
1308 delete this;
1309 runner([step_id, status, done_cb = std::move(done_cb)]() {
1310 profiler::TraceMeConsumer activity(
1311 // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
1312 // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
1313 [&] {
1314 return profiler::TraceMeEncode("ExecutorDoneCallback",
1315 {{"id", step_id}});
1316 },
1317 profiler::ContextType::kTfExecutor, step_id,
1318 profiler::TraceMeLevel::kInfo);
1319 done_cb(status);
1320 });
1321 });
1322 } else {
1323 delete this;
1324 runner([step_id, status, done_cb = std::move(done_cb)]() {
1325 profiler::TraceMeConsumer activity(
1326 // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
1327 // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
1328 [&] {
1329 return profiler::TraceMeEncode("ExecutorDoneCallback",
1330 {{"id", step_id}});
1331 },
1332 profiler::ContextType::kTfExecutor, step_id,
1333 profiler::TraceMeLevel::kInfo);
1334 done_cb(status);
1335 });
1336 }
1337 }
1338
RunAsync(const Args & args,DoneCallback done)1339 void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
1340 if (immutable_state_.requires_control_flow_support()) {
1341 (new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_))
1342 ->RunAsync(std::move(done));
1343 } else {
1344 (new ExecutorState<SimplePropagatorState>(args, immutable_state_,
1345 &kernel_stats_))
1346 ->RunAsync(std::move(done));
1347 }
1348 }
1349
1350 } // namespace
1351
NewLocalExecutor(const LocalExecutorParams & params,const Graph & graph,Executor ** executor)1352 Status NewLocalExecutor(const LocalExecutorParams& params, const Graph& graph,
1353 Executor** executor) {
1354 ExecutorImpl* impl = new ExecutorImpl(params);
1355 const Status s = impl->Initialize(graph);
1356 if (s.ok()) {
1357 *executor = impl;
1358 } else {
1359 delete impl;
1360 }
1361 return s;
1362 }
1363
CreateNonCachedKernel(Device * device,FunctionLibraryRuntime * flib,const std::shared_ptr<const NodeProperties> & props,int graph_def_version,OpKernel ** kernel)1364 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
1365 const std::shared_ptr<const NodeProperties>& props,
1366 int graph_def_version, OpKernel** kernel) {
1367 const auto device_type = DeviceType(device->attributes().device_type());
1368 auto allocator = device->GetAllocator(AllocatorAttributes());
1369 return CreateOpKernel(device_type, device, allocator, flib,
1370 device->resource_manager(), props, graph_def_version,
1371 kernel);
1372 }
1373
DeleteNonCachedKernel(OpKernel * kernel)1374 void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }
1375
1376 namespace {
1377
1378 class DefaultExecutorRegistrar {
1379 public:
DefaultExecutorRegistrar()1380 DefaultExecutorRegistrar() {
1381 Factory* factory = new Factory;
1382 ExecutorFactory::Register("", factory);
1383 ExecutorFactory::Register("DEFAULT", factory);
1384 }
1385
1386 private:
1387 class Factory : public ExecutorFactory {
NewExecutor(const LocalExecutorParams & params,const Graph & graph,std::unique_ptr<Executor> * out_executor)1388 Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
1389 std::unique_ptr<Executor>* out_executor) override {
1390 Executor* ret = nullptr;
1391 TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret));
1392 out_executor->reset(ret);
1393 return Status::OK();
1394 }
1395 };
1396 };
1397 static DefaultExecutorRegistrar registrar;
1398
1399 } // namespace
1400
1401 } // namespace tensorflow
1402