1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/executor.h"
17 
18 #include <atomic>
19 #include <memory>
20 #include <vector>
21 
22 #include "absl/memory/memory.h"
23 #include "tensorflow/core/common_runtime/costmodel_manager.h"
24 #include "tensorflow/core/common_runtime/entry.h"
25 #include "tensorflow/core/common_runtime/executor_factory.h"
26 #include "tensorflow/core/common_runtime/graph_view.h"
27 #include "tensorflow/core/common_runtime/immutable_executor_state.h"
28 #include "tensorflow/core/common_runtime/pending_counts.h"
29 #include "tensorflow/core/common_runtime/propagator_state.h"
30 #include "tensorflow/core/common_runtime/renamed_device.h"
31 #include "tensorflow/core/common_runtime/simple_propagator_state.h"
32 #include "tensorflow/core/common_runtime/step_stats_collector.h"
33 #include "tensorflow/core/framework/allocator.h"
34 #include "tensorflow/core/framework/cancellation.h"
35 #include "tensorflow/core/framework/collective.h"
36 #include "tensorflow/core/framework/control_flow.h"
37 #include "tensorflow/core/framework/device_attributes.pb.h"
38 #include "tensorflow/core/framework/log_memory.h"
39 #include "tensorflow/core/framework/metrics.h"
40 #include "tensorflow/core/framework/node_def_util.h"
41 #include "tensorflow/core/framework/op_kernel.h"
42 #include "tensorflow/core/framework/op_segment.h"
43 #include "tensorflow/core/framework/tensor.h"
44 #include "tensorflow/core/framework/tensor_reference.h"
45 #include "tensorflow/core/framework/types.h"
46 #include "tensorflow/core/framework/types.pb.h"
47 #include "tensorflow/core/graph/edgeset.h"
48 #include "tensorflow/core/graph/graph.h"
49 #include "tensorflow/core/graph/graph_node_util.h"
50 #include "tensorflow/core/lib/core/errors.h"
51 #include "tensorflow/core/lib/core/notification.h"
52 #include "tensorflow/core/lib/core/status.h"
53 #include "tensorflow/core/lib/core/threadpool.h"
54 #include "tensorflow/core/lib/gtl/flatmap.h"
55 #include "tensorflow/core/lib/gtl/inlined_vector.h"
56 #include "tensorflow/core/lib/gtl/manual_constructor.h"
57 #include "tensorflow/core/lib/hash/hash.h"
58 #include "tensorflow/core/platform/context.h"
59 #include "tensorflow/core/platform/env.h"
60 #include "tensorflow/core/platform/errors.h"
61 #include "tensorflow/core/platform/logging.h"
62 #include "tensorflow/core/platform/macros.h"
63 #include "tensorflow/core/platform/mutex.h"
64 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
65 #include "tensorflow/core/platform/thread_annotations.h"
66 #include "tensorflow/core/platform/tracing.h"
67 #include "tensorflow/core/platform/types.h"
68 #include "tensorflow/core/profiler/lib/annotated_traceme.h"
69 #include "tensorflow/core/profiler/lib/connected_traceme.h"
70 #include "tensorflow/core/profiler/lib/scoped_annotation.h"
71 #include "tensorflow/core/profiler/lib/traceme_encode.h"
72 #include "tensorflow/core/protobuf/error_codes.pb.h"
73 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
74 
75 namespace tensorflow {
76 
77 namespace {
78 
79 // 1-D, 0 element tensor.
80 static const Tensor* const kEmptyTensor = new Tensor;
81 
82 // Helper routines for collecting step stats.
83 namespace nodestats {
NowInNsec()84 inline int64 NowInNsec() { return EnvTime::NowNanos(); }
85 
SetScheduled(NodeExecStatsInterface * stats,int64 micros)86 void SetScheduled(NodeExecStatsInterface* stats, int64 micros) {
87   if (!stats) return;
88   stats->SetScheduled(micros * EnvTime::kMicrosToNanos);
89 }
90 
SetAllStart(NodeExecStatsInterface * stats)91 void SetAllStart(NodeExecStatsInterface* stats) {
92   if (!stats) return;
93   stats->RecordExecutorStarted();
94 }
95 
SetOpStart(NodeExecStatsInterface * stats)96 void SetOpStart(NodeExecStatsInterface* stats) {
97   if (!stats) return;
98   stats->RecordComputeStarted();
99 }
100 
SetOpEnd(NodeExecStatsInterface * stats)101 void SetOpEnd(NodeExecStatsInterface* stats) {
102   if (!stats) return;
103   stats->RecordComputeEnded();
104 }
105 
SetAllEnd(NodeExecStatsInterface * stats)106 void SetAllEnd(NodeExecStatsInterface* stats) {
107   if (!stats) return;
108   stats->RecordExecutorEnded();
109 }
110 
SetOutput(NodeExecStatsInterface * stats,int slot,const Tensor * v)111 void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) {
112   if (!stats) return;
113   stats->SetOutput(slot, v);
114 }
115 
SetMemory(NodeExecStatsInterface * stats,OpKernelContext * ctx)116 void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
117   if (!stats) return;
118   stats->SetMemory(ctx);
119 }
120 
121 }  // namespace nodestats
122 
123 // Time the execution of kernels (in CPU cycles).  Used to dynamically identify
124 // inexpensive kernels which can be dispatched inline.
125 struct KernelTimer {
126   uint64 start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle();
127 
ElapsedCyclestensorflow::__anon6f8fc96b0111::KernelTimer128   uint64 ElapsedCycles() {
129     return profile_utils::CpuUtils::GetCurrentClockCycle() - start_cycles;
130   }
131 };
132 
133 // TODO(b/152925936): Re-evaluate these constants with current usage patterns.
134 typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
135 typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
136 
137 class ExecutorImpl : public Executor {
138  public:
ExecutorImpl(const LocalExecutorParams & p)139   explicit ExecutorImpl(const LocalExecutorParams& p) : immutable_state_(p) {}
140 
Initialize(const Graph & graph)141   Status Initialize(const Graph& graph) {
142     TF_RETURN_IF_ERROR(immutable_state_.Initialize(graph));
143     kernel_stats_.Initialize(immutable_state_.graph_view());
144     return Status::OK();
145   }
146 
147   void RunAsync(const Args& args, DoneCallback done) override;
148 
149  private:
150   template <class PropagatorStateType>
151   friend class ExecutorState;
152 
153   // Stores execution time information about the kernels in an executor's graph.
154   class KernelStats {
155    public:
156     KernelStats() = default;
157 
Initialize(const GraphView & gview)158     void Initialize(const GraphView& gview) {
159       is_expensive_.resize(gview.num_nodes());
160       cost_estimates_ =
161           absl::make_unique<std::atomic_uint_fast64_t[]>(gview.num_nodes());
162       for (int32 i = 0; i < gview.num_nodes(); ++i) {
163         if (gview.node(i)) {
164           is_expensive_[i] =
165               gview.node(i)->kernel && gview.node(i)->kernel->IsExpensive();
166           cost_estimates_[i] = kInitialCostEstimateCycles;
167         }
168       }
169     }
170 
171     // Returns true iff the given node is considered "expensive". The
172     // executor uses this flag to optimize graph execution, for example
173     // by "inlining" inexpensive kernels.
IsExpensive(const NodeItem & node) const174     bool IsExpensive(const NodeItem& node) const {
175       return is_expensive_[node.node_id] &&
176              (cost_estimates_[node.node_id].load(std::memory_order_relaxed) >
177               kOpIsExpensiveThresholdCycles);
178     }
179 
180     // Returns the value of kernel->IsExpensive().
HasExpensiveMarker(const NodeItem & node) const181     bool HasExpensiveMarker(const NodeItem& node) const {
182       return is_expensive_[node.node_id];
183     }
184 
185     // Updates the dynamic cost estimate, which is used to determine whether the
186     // given node is expensive. The new cost estimate is a weighted average of
187     // the old cost estimate and the latest cost. We only update cost estimates
188     // for kernels for which IsExpensive() return true.
UpdateCostEstimate(const NodeItem & node,uint64 elapsed_cycles)189     void UpdateCostEstimate(const NodeItem& node, uint64 elapsed_cycles) {
190       // N.B. Updates to `cost_estimate` are atomic but unlocked.  Simultaneous
191       // updates may result in one or more updates being ignored.  This does not
192       // affect correctness but may slow down the update frequency.
193       std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node_id];
194       auto prev_estimate = cost_estimate.load(std::memory_order_relaxed);
195 
196       uint64 new_estimate =
197           ((kCostDecay - 1) * prev_estimate + elapsed_cycles) / kCostDecay;
198 
199       cost_estimate.store(new_estimate, std::memory_order_relaxed);
200     }
201 
202    private:
203     // Initial time (in CPU cycles) we expect an operation to take.  Used to
204     // determine whether an operation should be place in a threadpool.
205     // Operations start out "expensive".
206     static constexpr uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
207     static constexpr uint64 kOpIsExpensiveThresholdCycles = 8000;
208     static constexpr uint64 kCostDecay = 10;
209 
210     std::vector<bool> is_expensive_;
211     // std::unique_ptr<std::atomic<bool>[]> is_expensive_;
212     std::unique_ptr<std::atomic_uint_fast64_t[]> cost_estimates_;
213   };
214 
215   ImmutableExecutorState immutable_state_;
216   KernelStats kernel_stats_;
217 
218   TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
219 };
220 
221 // The state associated with one invocation of ExecutorImpl::Run.
222 //
223 // ExecutorState dispatches nodes when they become ready, and delegates to an
224 // instance of `PropagatorStateType` to keep track of how many predecessors of a
225 // are still pending.
226 //
227 // The template argument `class PropagatorStateType` must define the following
228 // public members:
229 // * A type `TaggedNode`, representing a node to be processed, with public
230 //   members:
231 //   * `const NodeItem& get_node_item() const`
232 //   * `bool get_is_dead() const`
233 // * A type `TaggedNodeReadyQueue`, representing a queue of nodes to be
234 //   processed, with public members (having the same meanings as in an
235 //   `std::vector<TaggedNode>`):
236 //   * `void push_back(const TaggedNode& node)`
237 //   * `TaggedNode front() const`
238 //   * `void pop_front()`
239 //   * `bool empty() const`
240 // * A type `TaggedNodeSeq`, representing a list of nodes to be schedules, with
241 //   public members (having the same meanings as in an
242 //   `std::vector<TaggedNode>`):
243 //   * `size_t size() const`
244 //   * `bool empty() const`
245 //   * `void clear()`
246 //   * `const_iterator begin() const`
247 //   * `const_iterator end() const`
248 // * A public constructor, `PropagatorStateType(const ImmutableExecutorState&
249 //   immutable_state, int64 step_id)`.
250 // * The following public methods:
251 //   * `void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
252 //     TaggedNodeSeq* ready)`, which creates `TaggedNode` instances for the
253 //     nodes in `roots` and adds them to `*ready`
254 //   * `void PropagateOutputs(const TaggedNode& tagged_node, EntryVector*
255 //     outputs, TaggedNodeSeq* ready)`, which propagates `outputs` from the
256 //     given `tagged_node` to the destinations of its output edges, and adds
257 //     any newly runnable nodes to `*ready`
258 //   * `Entry* GetInputTensors(const TaggedNode& tagged_node) const`, which
259 //     returns a pointer to the input tensors for the given `tagged_node`
260 //   * `FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const`,
261 //     which creates a `FrameAndIter` for the given `tagged_node`
262 //   * `void DumpState()`, which dumps the dynamic state of the executing graph
263 //   * `void MaybeMarkStarted(const TaggedNode& tagged_node)`, which records
264 //     that a node has started
265 //   * `void MaybeMarkCompleted(const TaggedNode& tagged_node)`, which records
266 //     that a node has completed
267 //
268 // See `PropagatorState` in "./propagator_state.h" for an example of a type that
269 // can be used to instantiate `PropagatorStateType`.
270 template <class PropagatorStateType>
271 class ExecutorState {
272  public:
273   ExecutorState(const Executor::Args& args,
274                 const ImmutableExecutorState& immutable_state_,
275                 ExecutorImpl::KernelStats* kernel_stats_);
276   ~ExecutorState();
277 
278   void RunAsync(Executor::DoneCallback done);
279 
280  private:
281   // Use `TaggedNode` types defined by `PropagatorStateType`.
282   typedef typename PropagatorStateType::TaggedNode TaggedNode;
283   typedef
284       typename PropagatorStateType::TaggedNodeReadyQueue TaggedNodeReadyQueue;
285   typedef typename PropagatorStateType::TaggedNodeSeq TaggedNodeSeq;
286 
287   struct AsyncState;
288 
289   // Process a ready node in current thread.
290   void Process(TaggedNode node, int64 scheduled_nsec);
291 
292   Status ProcessSync(const NodeItem& item, OpKernelContext::Params* params,
293                      EntryVector* outputs, NodeExecStatsInterface* stats);
294   void ProcessAsync(const NodeItem& item, const OpKernelContext::Params& params,
295                     const TaggedNode& tagged_node, Entry* first_input,
296                     NodeExecStatsInterface* stats);
297   void ProcessNoop(NodeExecStatsInterface* stats);
298   void ProcessConstTensor(const NodeItem& item, EntryVector* outputs,
299                           NodeExecStatsInterface* stats);
300 
301   // Before invoking item->kernel, fills in its "inputs".
302   Status PrepareInputs(const NodeItem& item, Entry* first_input,
303                        TensorValueVec* inputs,
304                        AllocatorAttributeVec* input_alloc_attrs,
305                        bool* is_input_dead);
306 
307   // After item->kernel computation is done, processes its outputs.
308   Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
309                         Entry* outputs, NodeExecStatsInterface* stats);
310 
311   // Called after each node finishes. Takes ownership of "stats". Returns true
312   // if execution has completed.
313   //
314   // This method will clear `*ready` before returning.
315   bool NodeDone(const Status& s, TaggedNodeSeq* ready,
316                 NodeExecStatsInterface* stats,
317                 TaggedNodeReadyQueue* inline_ready);
318 
319   // Schedule all the expensive nodes in '*ready', and put all the inexpensive
320   // nodes in 'ready' into 'inline_ready'.
321   //
322   // This method will clear `*ready` before returning.
323   //
324   // REQUIRES: `!ready->empty()`.
325   void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready);
326 
327   // A wrapper for runner_ to keep track of the pending queue length. Op
328   // execution should dispatch work using this function instead of using runner_
329   // directly.
330   template <typename Closure>
331   void RunTask(Closure&& c);
332 
333   // Clean up when this executor is done.
334   void Finish();
335   void ScheduleFinish();
336 
337   // Contains the device context assigned by the device at the beginning of a
338   // step.
339   DeviceContext* device_context_ = nullptr;
340 
341   const bool vlog_;  // true if VLOG_IS_ON(1). Used to check vlog cheaply.
342 
343   // true if LogMemory::IsEnabled(). Used to check memory enabled cheaply.
344   const bool log_memory_;
345 
346   int64 step_id_;
347   // Not owned.
348   RendezvousInterface* rendezvous_;
349   CollectiveExecutor* collective_executor_ = nullptr;
350   SessionState* session_state_;
351   string session_handle_;
352   const SessionMetadata* session_metadata_ = nullptr;
353   TensorStore* tensor_store_;
354   // Step-local container.
355   ScopedStepContainer* step_container_;
356   StepStatsCollectorInterface* const stats_collector_;
357   const tracing::EventCollector* const event_collector_;
358   Context context_;
359 
360   // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
361   // instead of a pointer?  (avoids having to delete).
362   checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
363   CallFrameInterface* call_frame_;
364   const ImmutableExecutorState& immutable_state_;
365   ExecutorImpl::KernelStats* const kernel_stats_;
366   CancellationManager* cancellation_manager_;
367   // If not null, use this device to schedule intra-op operation
368   std::unique_ptr<DeviceBase> user_device_;
369   Executor::Args::Runner runner_;
370   bool sync_on_finish_;
371   const bool run_all_kernels_inline_;
372 
373   PropagatorStateType propagator_;
374 
375   // Invoked when the execution finishes.
376   Executor::DoneCallback done_cb_;
377 
378   std::atomic_int_fast32_t num_outstanding_ops_;
379 
380   // Available via OpKernelContext to every OpKernel invocation.
381   mutex num_deferred_ops_mu_;
382   int64 num_deferred_ops_ TF_GUARDED_BY(num_deferred_ops_mu_) = 0;
383   bool finish_when_deferred_ops_done_ TF_GUARDED_BY(num_deferred_ops_mu_) =
384       false;
385 
386   mutex mu_;
387   Status status_ TF_GUARDED_BY(mu_);
388 };
389 
390 template <class PropagatorStateType>
ExecutorState(const Executor::Args & args,const ImmutableExecutorState & immutable_state,ExecutorImpl::KernelStats * kernel_stats)391 ExecutorState<PropagatorStateType>::ExecutorState(
392     const Executor::Args& args, const ImmutableExecutorState& immutable_state,
393     ExecutorImpl::KernelStats* kernel_stats)
394     : vlog_(VLOG_IS_ON(1)),
395       log_memory_(LogMemory::IsEnabled()),
396       step_id_(args.step_id),
397       rendezvous_(args.rendezvous),
398       collective_executor_(args.collective_executor),
399       session_state_(args.session_state),
400       session_handle_(args.session_handle),
401       session_metadata_(immutable_state.params().session_metadata),
402       tensor_store_(args.tensor_store),
403       step_container_(args.step_container),
404       stats_collector_(args.stats_collector),
405       event_collector_(
406           tracing::GetEventCollector(tracing::EventCategory::kCompute)),
407       context_(ContextKind::kThread),
408       slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
409       call_frame_(args.call_frame),
410       immutable_state_(immutable_state),
411       kernel_stats_(kernel_stats),
412       cancellation_manager_(args.cancellation_manager),
413       runner_(args.runner),
414       sync_on_finish_(args.sync_on_finish),
415       run_all_kernels_inline_(args.run_all_kernels_inline),
416       propagator_(immutable_state, step_id_, vlog_),
417       num_outstanding_ops_(0) {
418   if (args.user_intra_op_threadpool != nullptr) {
419     Device* device = immutable_state_.params().device;
420     user_device_ = RenamedDevice::NewRenamedDevice(
421         device->name(), device, false, false, args.user_intra_op_threadpool);
422   }
423 }
424 
425 template <class PropagatorStateType>
~ExecutorState()426 ExecutorState<PropagatorStateType>::~ExecutorState() {
427   if (device_context_) {
428     device_context_->Unref();
429   }
430   delete slice_reader_cache_;
431 }
432 
433 template <class PropagatorStateType>
434 template <typename Closure>
RunTask(Closure && c)435 void ExecutorState<PropagatorStateType>::RunTask(Closure&& c) {
436   // Align the atomic variables at 64 bytes to avoid false-sharing, assuming the
437   // cacheline size is 64 bytes or smaller.
438   alignas(64) static std::atomic<int64_t> num_enqueue_ops{0};
439   alignas(64) static std::atomic<int64_t> num_dequeue_ops{0};
440 
441   auto n_enqueues = num_enqueue_ops.fetch_add(1, std::memory_order_relaxed);
442   // Sample the queue length on every 16 enqueue operations. This amortizes the
443   // cost of metric updates across 16 operations.
444   if (n_enqueues % 16 == 0) {
445     auto n_dequeues = num_dequeue_ops.load(std::memory_order_relaxed);
446     metrics::UpdateGraphPendingQueueLength(n_enqueues - n_dequeues);
447   }
448 
449   // mutable is needed because std::forward<Closure> in the lambda body may move
450   // the Closure `c`.
451   runner_([c = std::forward<Closure>(c)]() mutable {
452     num_dequeue_ops.fetch_add(1, std::memory_order_relaxed);
453     std::forward<Closure>(c)();
454   });
455 }
456 
457 template <class PropagatorStateType>
RunAsync(Executor::DoneCallback done)458 void ExecutorState<PropagatorStateType>::RunAsync(Executor::DoneCallback done) {
459   TaggedNodeSeq ready;
460 
461   // Ask the device to fill in the device context map.
462   Device* device = immutable_state_.params().device;
463   const Status get_context_status =
464       device->TryGetDeviceContext(&device_context_);
465   if (!get_context_status.ok()) {
466     delete this;
467     done(get_context_status);
468     return;
469   }
470 
471   // Initialize the ready queue.
472   ready.reserve(immutable_state_.root_nodes().size());
473   propagator_.ActivateRoots(immutable_state_.root_nodes(), &ready);
474   num_outstanding_ops_ = ready.size();
475   if (ready.empty()) {
476     delete this;
477     done(Status::OK());
478   } else {
479     done_cb_ = std::move(done);
480     // Schedule to run all the ready ops in thread pool.
481     ScheduleReady(&ready, nullptr);
482   }
483 }
484 
485 // State kept alive for executing an asynchronous node in another
486 // thread.  NOTE: We need to make a copy of p.input and p.input_alloc_attrs for
487 // asynchronous kernels because OpKernelContext methods like input_type(i) needs
488 // the param points to valid input type vector. It's not an issue for
489 // sync kernels because these vectors are kept on the stack.
490 template <class PropagatorStateType>
491 struct ExecutorState<PropagatorStateType>::AsyncState {
AsyncStatetensorflow::__anon6f8fc96b0111::ExecutorState::AsyncState492   AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
493              const NodeItem* _item, Entry* _first_input,
494              NodeExecStatsInterface* _stats)
495       : saved_inputs(*p.inputs),
496         saved_input_alloc_attrs(*p.input_alloc_attrs),
497         params(p),
498         tagged_node(_tagged_node),
499         item(_item),
500         first_input(_first_input),
501         // ParamsButClearingEigenGPUDevice does equivalent of
502         //   params.eigen_gpu_device = nullptr;
503         ctx(ParamsButClearingEigenGPUDevice(&params), item->num_outputs),
504         stats(_stats) {
505     params.inputs = &saved_inputs;
506     params.input_alloc_attrs = &saved_input_alloc_attrs;
507   }
508 
509   TensorValueVec saved_inputs;
510   AllocatorAttributeVec saved_input_alloc_attrs;
511   OpKernelContext::Params params;
512   TaggedNode tagged_node;
513   const NodeItem* item;
514   Entry* first_input;
515   OpKernelContext ctx;
516   NodeExecStatsInterface* stats;
517 
518  private:
ParamsButClearingEigenGPUDevicetensorflow::__anon6f8fc96b0111::ExecutorState::AsyncState519   OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
520       OpKernelContext::Params* p) {
521     // Ensure OpKernelContext constructor will make a new eigen GPU device if
522     // necessary.
523     p->eigen_gpu_device = nullptr;  // Force allocation
524     return p;
525   }
526 };
527 
528 // Returns true if `item` might be traced by the given trace and event
529 // collectors. Returns false only if `item` definitely will not be traced.
MightTrace(const tracing::EventCollector * event_collector,bool is_expensive)530 bool MightTrace(const tracing::EventCollector* event_collector,
531                 bool is_expensive) {
532   // Tracing will only be enabled if either `event_collector` is non null,
533   // or `trace_collector` is non-null and enabled for this particular kernel.
534   // Although `profiler::TraceMe`, `profiler::ScopedAnnotation`, and
535   // `tracing::ScopedRegion` check subsets of these properties internally in
536   // their constructors, the cost of passing the necessary arguments to them can
537   // be significant, so we avoid constructing them in the common case (when we
538   // know they will not be used).
539   if (event_collector != nullptr) {
540     return true;
541   }
542 
543   if (profiler::ScopedAnnotation::IsEnabled()) return true;
544 
545   return profiler::TraceMe::Active(profiler::GetTFTraceMeLevel(is_expensive));
546 }
547 
548 template <class PropagatorStateType>
ProcessSync(const NodeItem & item,OpKernelContext::Params * params,EntryVector * outputs,NodeExecStatsInterface * stats)549 Status ExecutorState<PropagatorStateType>::ProcessSync(
550     const NodeItem& item, OpKernelContext::Params* params, EntryVector* outputs,
551     NodeExecStatsInterface* stats) {
552   Status s;
553   OpKernelContext ctx(params, item.num_outputs);
554   nodestats::SetOpStart(stats);
555 
556   OpKernel* op_kernel = item.kernel;
557   Device* device = immutable_state_.params().device;
558   const bool is_expensive = kernel_stats_->IsExpensive(item);
559 
560   if (TF_PREDICT_FALSE(MightTrace(event_collector_, is_expensive))) {
561     tracing::ScopedRegion region(tracing::EventCategory::kCompute,
562                                  op_kernel->name_view());
563     profiler::AnnotatedTraceMe activity(
564         [op_kernel, &ctx] {
565           return op_kernel->TraceString(
566               ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
567         },
568         profiler::GetTFTraceMeLevel(is_expensive));
569     device->Compute(op_kernel, &ctx);
570   } else if (kernel_stats_->HasExpensiveMarker(item)) {
571     KernelTimer timer;
572     device->Compute(op_kernel, &ctx);
573     // For expensive kernels, always update the cost estimate. For inexpensive
574     // kernels, update the cost estimate with ~1/16 probability. This assumes
575     // that the last 4 bits of the CPU cycle count is uniformly distributed.
576     constexpr int kKernelExecutionTrackingInvocationSkipCount = 16;
577     if (is_expensive ||
578         timer.start_cycles % kKernelExecutionTrackingInvocationSkipCount == 0) {
579       kernel_stats_->UpdateCostEstimate(item, timer.ElapsedCycles());
580     }
581   } else {
582     device->Compute(op_kernel, &ctx);
583   }
584   nodestats::SetOpEnd(stats);
585   if (outputs->size() < item.num_outputs) outputs->resize(item.num_outputs);
586   s = ProcessOutputs(item, &ctx, outputs->data(), stats);
587   nodestats::SetMemory(stats, &ctx);
588   return s;
589 }
590 
591 template <class PropagatorStateType>
ProcessAsync(const NodeItem & item,const OpKernelContext::Params & params,const TaggedNode & tagged_node,Entry * first_input,NodeExecStatsInterface * stats)592 void ExecutorState<PropagatorStateType>::ProcessAsync(
593     const NodeItem& item, const OpKernelContext::Params& params,
594     const TaggedNode& tagged_node, Entry* first_input,
595     NodeExecStatsInterface* stats) {
596   AsyncOpKernel* async_kernel = item.kernel->AsAsync();
597   DCHECK(async_kernel != nullptr);
598   AsyncState* state =
599       new AsyncState(params, tagged_node, &item, first_input, stats);
600 
601   auto done = [this, state]() {
602     Device* device = immutable_state_.params().device;
603     NodeExecStatsInterface* stats = state->stats;  // Shorthand
604     Entry* first_input = state->first_input;       // Shorthand
605 
606     nodestats::SetOpEnd(stats);
607     EntryVector outputs(state->item->num_outputs);
608     Status s = ProcessOutputs(*state->item, &state->ctx, outputs.data(), stats);
609     nodestats::SetMemory(stats, &state->ctx);
610     if (vlog_) {
611       VLOG(2) << "Async kernel done: " << state->item->node_id << " step "
612               << step_id_ << " " << SummarizeNodeDef(state->item->kernel->def())
613               << (state->tagged_node.get_is_dead() ? " is dead" : "")
614               << " device: " << device->name();
615     }
616 
617     // Clears inputs.
618     const int num_inputs = state->item->num_inputs;
619     for (int i = 0; i < num_inputs; ++i) {
620       (first_input + i)->ClearVal();
621     }
622     propagator_.MaybeMarkCompleted(state->tagged_node);
623     TaggedNodeSeq ready;
624     if (s.ok()) {
625       propagator_.PropagateOutputs(state->tagged_node, &outputs, &ready);
626     }
627     outputs.clear();
628     const bool completed = NodeDone(s, &ready, stats, nullptr);
629     delete state;
630     if (completed) ScheduleFinish();
631   };
632   nodestats::SetOpStart(stats);
633   {
634     profiler::AnnotatedTraceMe activity(
635         [async_kernel, state] {
636           return async_kernel->TraceString(
637               state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
638         },
639         profiler::GetTFTraceMeLevel(kernel_stats_->IsExpensive(item)));
640     immutable_state_.params().device->ComputeAsync(async_kernel, &state->ctx,
641                                                    std::move(done));
642   }
643 }
644 
645 template <class PropagatorStateType>
ProcessNoop(NodeExecStatsInterface * stats)646 void ExecutorState<PropagatorStateType>::ProcessNoop(
647     NodeExecStatsInterface* stats) {
648   nodestats::SetOpStart(stats);
649   nodestats::SetOpEnd(stats);
650 }
651 
652 template <class PropagatorStateType>
ProcessConstTensor(const NodeItem & item,EntryVector * outputs,NodeExecStatsInterface * stats)653 void ExecutorState<PropagatorStateType>::ProcessConstTensor(
654     const NodeItem& item, EntryVector* outputs, NodeExecStatsInterface* stats) {
655   nodestats::SetOpStart(stats);
656   nodestats::SetOpEnd(stats);
657   Entry& output = (*outputs)[0];
658   output.state = Entry::State::HAS_CONST_TENSOR;
659   output.const_tensor = item.const_tensor;
660   output.alloc_attr = item.output_attrs()[0];
661 }
662 
// Runs `tagged_node` and then every node that gets queued on the local
// `inline_ready` queue as a result, all on the calling thread.
//
// `tagged_node`    - the first node to execute.
// `scheduled_nsec` - timestamp recorded as the node's scheduling time when
//                    step stats are being collected.
//
// For each node the loop gathers its inputs, runs the kernel (or takes one of
// the dead / no-op / const-tensor shortcuts), propagates outputs to find newly
// ready nodes, and calls NodeDone(). If the final NodeDone() reports that no
// operations remain outstanding, ScheduleFinish() is invoked.
template <class PropagatorStateType>
void ExecutorState<PropagatorStateType>::Process(TaggedNode tagged_node,
                                                 int64 scheduled_nsec) {
  profiler::TraceMeConsumer activity(
      // From TraceMeProducer in DirectSession::RunInternal,
      // GraphMgr::ExecuteAsync, or FunctionLibraryRuntime::Run.
      [&] {
        // NOTE: This tracing uses the iteration number from the first tagged
        // node that executes during this call to `Process()`. In principle,
        // subsequent nodes could have different values of `iter_num` that
        // will not be traced.
        return profiler::TraceMeEncode(
            "ExecutorState::Process",
            {{"id", step_id_}, {"iter_num", tagged_node.get_iter_num()}});
      },
      profiler::ContextType::kTfExecutor, step_id_,
      profiler::TraceMeLevel::kInfo);
  WithContext wc(context_);
  // `ready` collects nodes made runnable by output propagation; NodeDone()
  // either moves them onto `inline_ready` (to run in the loop below) or
  // dispatches them elsewhere via ScheduleReady().
  TaggedNodeSeq ready;
  TaggedNodeReadyQueue inline_ready;

  // Parameters passed to OpKernel::Compute.
  TensorValueVec inputs;
  AllocatorAttributeVec input_alloc_attrs;

  // `params` is built once here and reused for every node processed by this
  // call; only the per-node fields are overwritten inside the loop.
  OpKernelContext::Params params;
  params.step_id = step_id_;
  // Override device's threadpool if user provides an intra_op_threadpool
  Device* device = immutable_state_.params().device;
  if (user_device_) {
    params.device = user_device_.get();
  } else {
    params.device = device;
  }
  params.log_memory = log_memory_;
  params.rendezvous = rendezvous_;
  params.collective_executor = collective_executor_;
  params.session_state = session_state_;
  params.session_handle = session_handle_;
  params.session_metadata = session_metadata_;
  params.tensor_store = tensor_store_;
  params.cancellation_manager = cancellation_manager_;
  params.call_frame = call_frame_;
  params.function_library = immutable_state_.params().function_library;
  // Note: the resource manager always comes from the executor's own device,
  // even when `params.device` was overridden with `user_device_` above.
  params.resource_manager = device->resource_manager();
  params.step_container = step_container_;
  params.slice_reader_cache = slice_reader_cache_;
  params.inputs = &inputs;
  params.input_alloc_attrs = &input_alloc_attrs;
  params.runner = &runner_;
  params.run_all_kernels_inline = run_all_kernels_inline_;
  params.stats_collector = stats_collector_;
  // Hooks that let a kernel register work that outlives its Compute call.
  // The counter is guarded by `num_deferred_ops_mu_`.
  params.inc_num_deferred_ops_function = [this]() {
    mutex_lock lock(num_deferred_ops_mu_);
    num_deferred_ops_++;
  };
  params.dec_num_deferred_ops_function = [this]() {
    bool finish_when_deferred_ops_done = false;
    {
      mutex_lock lock(num_deferred_ops_mu_);
      num_deferred_ops_--;
      if (num_deferred_ops_ == 0) {
        finish_when_deferred_ops_done = finish_when_deferred_ops_done_;
      }
    }
    // Invoke Finish if the graph processing has completed. Finish is always
    // called exactly once per ExecutorState, either here if there are any
    // deferred ops, or in ScheduleFinish if there aren't any deferred ops.
    if (finish_when_deferred_ops_done) Finish();
  };

  // Set the device_context for this device, if it exists.
  params.op_device_context = device_context_;

  // NOTE(review): `s` lives across loop iterations and is only reassigned by
  // PrepareInputs() and ProcessSync(). For dead, no-op and const-tensor nodes
  // it keeps whatever value an earlier iteration left behind, so a previous
  // failure suppresses their output propagation too. Confirm this is intended.
  Status s;
  NodeExecStatsInterface* stats = nullptr;

  // Output entries, reused for every node run by this call; grown on demand
  // and cleared (not deallocated) after each node.
  EntryVector outputs(1);

  bool completed = false;
  inline_ready.push_back(tagged_node);
  while (!inline_ready.empty()) {
    tagged_node = inline_ready.front();
    inline_ready.pop_front();
    const NodeItem& item = tagged_node.get_node_item();
    const int id = item.node_id;

    propagator_.MaybeMarkStarted(tagged_node);

    // Stats are only gathered for live nodes, and only when a collector is
    // attached to this step.
    params.track_allocations = false;
    stats = nullptr;
    if (stats_collector_ && !tagged_node.get_is_dead()) {
      stats = stats_collector_->CreateNodeExecStats(&item.kernel->def());
      // Track allocations if and only if we are collecting statistics, and
      // `stats` object is expecting allocations to be tracked.
      params.track_allocations = stats ? stats->TrackAllocations() : false;
      nodestats::SetScheduled(stats, scheduled_nsec);
      nodestats::SetAllStart(stats);
    }

    if (vlog_) {
      VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
              << SummarizeNodeDef(item.kernel->def())
              << (tagged_node.get_is_dead() ? " is dead" : "")
              << " device: " << device->name();
    }

    Entry* first_input = propagator_.GetInputTensors(tagged_node);

    // Only execute this node if it is not dead or it is a send/recv
    // transfer node. For transfer nodes, we need to propagate the "dead"
    // bit even when the node is dead.
    bool launched_asynchronously = false;
    if (tagged_node.get_is_dead() && !item.is_transfer_node) {
      // Dead node: no kernel runs. Just make sure `outputs` has one slot per
      // output for the propagation step below.
      if (outputs.size() < item.num_outputs) outputs.resize(item.num_outputs);
    } else if (TF_PREDICT_FALSE(item.is_noop)) {
      ProcessNoop(stats);
    } else if (item.const_tensor != nullptr && !params.track_allocations) {
      // Const-tensor shortcut; not taken while allocations are being tracked,
      // in which case the node goes through the regular kernel path.
      ProcessConstTensor(item, &outputs, stats);
    } else {
      // Prepares inputs.
      bool is_input_dead = false;
      s = PrepareInputs(item, first_input, &inputs, &input_alloc_attrs,
                        &is_input_dead);
      if (!s.ok()) {
        // Clear inputs.
        const int num_inputs = item.num_inputs;
        for (int i = 0; i < num_inputs; ++i) {
          (first_input + i)->ClearVal();
        }
        propagator_.MaybeMarkCompleted(tagged_node);
        // Continue to process the nodes in 'inline_ready'.
        completed = NodeDone(s, &ready, stats, &inline_ready);
        continue;
      }

      // Set up compute params.
      params.op_kernel = item.kernel;
      params.frame_iter = propagator_.GetFrameAndIter(tagged_node);
      params.is_input_dead = is_input_dead;
      params.output_attr_array = item.output_attrs();
      params.forward_from_array = item.forward_from();
      params.outputs_required_array = item.outputs_required.get();

      if (item.kernel_is_async) {
        // The synchronous post-processing below is skipped for async kernels;
        // presumably ProcessAsync's completion path does the equivalent work.
        ProcessAsync(item, params, tagged_node, first_input, stats);
        launched_asynchronously = true;
      } else {
        s = ProcessSync(item, &params, &outputs, stats);
      }
    }

    if (!launched_asynchronously) {
      if (vlog_) {
        VLOG(2) << "Synchronous kernel done: " << id << " step "
                << params.step_id << " " << SummarizeNodeDef(item.kernel->def())
                << (tagged_node.get_is_dead() ? " is dead: " : "")
                << " device: " << device->name();
      }

      // Clears inputs.
      const int num_inputs = item.num_inputs;
      for (int i = 0; i < num_inputs; ++i) {
        (first_input + i)->ClearVal();
      }
      propagator_.MaybeMarkCompleted(tagged_node);
      // Propagates outputs.
      if (s.ok()) {
        propagator_.PropagateOutputs(tagged_node, &outputs, &ready);
      }

      // Clear outputs without deallocating the `outputs` vector.
      const int num_outputs = item.num_outputs;
      for (int i = 0; i < num_outputs; ++i) {
        outputs[i].ClearVal();
      }

      // Nodes run inline after this one are recorded as scheduled "now".
      if (stats) {
        scheduled_nsec = nodestats::NowInNsec();
      }
      // Postprocess.
      completed = NodeDone(s, &ready, stats, &inline_ready);
    }
  }  // while !inline_ready.empty()

  // This thread of computation is done if completed = true.
  if (completed) ScheduleFinish();
}
851 
852 template <class PropagatorStateType>
PrepareInputs(const NodeItem & item,Entry * first_input,TensorValueVec * inputs,AllocatorAttributeVec * input_alloc_attrs,bool * is_input_dead)853 Status ExecutorState<PropagatorStateType>::PrepareInputs(
854     const NodeItem& item, Entry* first_input, TensorValueVec* inputs,
855     AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) {
856   inputs->resize(item.num_inputs);
857   input_alloc_attrs->resize(item.num_inputs);
858 
859   *is_input_dead = false;
860 
861   for (int i = 0; i < item.num_inputs; ++i) {
862     const bool expect_ref = TF_PREDICT_FALSE(item.is_any_input_ref_typed) &&
863                             IsRefType(item.input_type(i));
864     Entry* entry = first_input + i;
865     (*input_alloc_attrs)[i] = entry->alloc_attr;
866 
867     // i-th input.
868     TensorValue* inp = &(*inputs)[i];
869 
870     switch (entry->state) {
871       case Entry::State::NO_VALUE: {
872         // Only merge and transfer nodes can have no-value inputs.
873         inp->mutex_if_ref = nullptr;
874         if (item.is_merge) {
875           inp->tensor = nullptr;
876         } else {
877           DCHECK(item.is_transfer_node)
878               << item.kernel->name() << " - input " << i;
879           entry->state = Entry::State::HAS_CONST_TENSOR;
880           entry->const_tensor = kEmptyTensor;
881           // NOTE(mrry): This `const_cast` is necessary because `TensorValue`
882           // stores a non-const `Tensor*`, and relies on the `OpKernelContext`
883           // accessors making dynamic checks that prevent using an immutable
884           // tensor as a mutable tensor.
885           inp->tensor = const_cast<Tensor*>(kEmptyTensor);
886           *is_input_dead = true;
887         }
888         break;
889       }
890 
891       case Entry::State::HAS_VALUE: {
892         if (TF_PREDICT_FALSE(expect_ref)) {
893           return AttachDef(
894               errors::InvalidArgument(i, "-th input expects a ref type"),
895               item.kernel->def());
896         }
897         inp->mutex_if_ref = nullptr;
898         inp->tensor = entry->val.get();
899         break;
900       }
901 
902       case Entry::State::HAS_CONST_TENSOR: {
903         if (TF_PREDICT_FALSE(expect_ref)) {
904           return AttachDef(
905               errors::InvalidArgument(i, "-th input expects a ref type"),
906               item.kernel->def());
907         }
908         // NOTE(mrry): This `const_cast` is necessary because `TensorValue`
909         // stores a non-const `Tensor*`, and relies on the `OpKernelContext`
910         // accessors making dynamic checks that prevent using an immutable
911         // tensor as a mutable tensor.
912         inp->mutex_if_ref = nullptr;
913         inp->tensor = const_cast<Tensor*>(entry->const_tensor);
914         break;
915       }
916 
917       case Entry::State::HAS_REF_TENSOR: {
918         {
919           tf_shared_lock ml(*entry->ref_tensor.mu);
920           if (TF_PREDICT_FALSE(!entry->ref_tensor.tensor->IsInitialized() &&
921                                !item.is_initialization_op)) {
922             return AttachDef(errors::FailedPrecondition(
923                                  "Attempting to use uninitialized value ",
924                                  item.kernel->requested_input(i)),
925                              item.kernel->def());
926           }
927         }
928 
929         if (expect_ref) {
930           inp->mutex_if_ref = entry->ref_tensor.mu;
931           inp->tensor = entry->ref_tensor.tensor;
932         } else {
933           // Automatically deref the tensor ref when the op expects a
934           // tensor but is given a ref to a tensor.  Need to deref it
935           // under the mutex.
936           {
937             mutex* ref_mu = entry->ref_tensor.mu;
938             Tensor* ref_tensor = entry->ref_tensor.tensor;
939             tf_shared_lock l(*ref_mu);
940             entry->val.Init(*ref_tensor);
941           }
942           entry->state = Entry::State::HAS_VALUE;
943 
944           inp->mutex_if_ref = nullptr;
945           inp->tensor = entry->val.get();
946           // The dtype of entry->ref_tensor.tensor could have been changed by
947           // another operation that ran after the operation that "produced" it
948           // executed, so re-validate that the type of the dereferenced tensor
949           // matches the expected input type.
950           if (TF_PREDICT_FALSE(item.input_type(i) != inp->tensor->dtype())) {
951             return AttachDef(
952                 errors::InvalidArgument(
953                     i, "-th input expects type ",
954                     DataTypeString(item.input_type(i)),
955                     " but automatically dereferenced input tensor has type ",
956                     DataTypeString(inp->tensor->dtype())),
957                 item.kernel->def());
958           }
959         }
960         break;
961       }
962     }
963   }
964   return Status::OK();
965 }
966 
967 template <class PropagatorStateType>
ProcessOutputs(const NodeItem & item,OpKernelContext * ctx,Entry * outputs,NodeExecStatsInterface * stats)968 Status ExecutorState<PropagatorStateType>::ProcessOutputs(
969     const NodeItem& item, OpKernelContext* ctx, Entry* outputs,
970     NodeExecStatsInterface* stats) {
971   Status s = ctx->status();
972   if (!s.ok()) {
973     s = AttachDef(s, item.kernel->def());
974     // TODO(misard) Replace with a finer-grain enabling flag once we
975     // add better optional debugging support.
976     if (vlog_ && VLOG_IS_ON(1)) {
977       LOG(WARNING) << this << " Compute status: " << s;
978     }
979     if (s.code() == error::RESOURCE_EXHAUSTED) {
980       if (stats_collector_) {
981         string err = stats_collector_->ReportAllocsOnResourceExhausted(
982             s.error_message());
983         s = Status(s.code(), strings::StrCat(s.error_message(), err));
984       } else {
985         s = Status(
986             s.code(),
987             strings::StrCat(
988                 s.error_message(),
989                 "\nHint: If you want to see a list of allocated tensors when "
990                 "OOM happens, add report_tensor_allocations_upon_oom "
991                 "to RunOptions for current allocation info.\n"));
992       }
993     }
994     return s;
995   }
996 
997   for (int i = 0; i < item.num_outputs; ++i) {
998     const TensorValue val = ctx->release_output(i);
999     Entry* out = &outputs[i];
1000     DCHECK(out->state == Entry::State::NO_VALUE);
1001 
1002     if (val.tensor == nullptr) {
1003       // Unless it's a Switch or a Recv, or the executor has marked the output
1004       // as not required, the node must produce a tensor value at i-th output.
1005       if (!(item.is_recv_or_switch ||
1006             (item.outputs_required && !item.outputs_required[i]))) {
1007         s.Update(errors::Internal("Missing ", i, "-th output from ",
1008                                   FormatNodeDefForError(item.kernel->def())));
1009       }
1010     } else {
1011       // Set the allocator attributes of the output entry.
1012       out->alloc_attr = ctx->output_alloc_attr(i);
1013 
1014       // Sanity check of output tensor types. We need to inspect this safely as
1015       // we are in the tensor buffer.
1016       DataType dtype = val.dtype_safe();
1017       if (dtype == item.output_type(i)) {
1018         if (stats && val.tensor->IsInitialized()) {
1019           nodestats::SetOutput(stats, i, val.tensor);
1020         }
1021         if (val.is_ref()) {
1022           out->state = Entry::State::HAS_REF_TENSOR;
1023           out->ref_tensor.tensor = val.tensor;
1024           out->ref_tensor.mu = val.mutex_if_ref;
1025           if (log_memory_) {
1026             Tensor to_log;
1027             {
1028               // Dereference the tensor under the lock.
1029               tf_shared_lock l(*out->ref_tensor.mu);
1030               to_log = *out->ref_tensor.tensor;
1031             }
1032             LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
1033                                           ctx->step_id(), i, to_log);
1034           }
1035         } else {
1036           // NOTE that std::move is used here, so val.tensor goes to
1037           // uninitialized state (val.tensor->IsInitialized return false).
1038           out->state = Entry::State::HAS_VALUE;
1039           out->val.Init(std::move(*val.tensor));
1040           if (log_memory_) {
1041             LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
1042                                           ctx->step_id(), i, *out->val);
1043           }
1044         }
1045       } else {
1046         s.Update(
1047             errors::Internal("Output ", i, " of type ", DataTypeString(dtype),
1048                              " does not match declared output type ",
1049                              DataTypeString(item.output_type(i)), " for node ",
1050                              FormatNodeDefForError(item.kernel->def())));
1051       }
1052     }
1053     if (!val.is_ref()) {
1054       // If OpKernelContext returns outputs via pass-by-value, we
1055       // don't need this trouble.
1056       delete val.tensor;
1057     }
1058   }
1059   return s;
1060 }
1061 
1062 template <class PropagatorStateType>
NodeDone(const Status & s,TaggedNodeSeq * ready,NodeExecStatsInterface * stats,TaggedNodeReadyQueue * inline_ready)1063 bool ExecutorState<PropagatorStateType>::NodeDone(
1064     const Status& s, TaggedNodeSeq* ready, NodeExecStatsInterface* stats,
1065     TaggedNodeReadyQueue* inline_ready) {
1066   if (stats) {
1067     nodestats::SetAllEnd(stats);
1068     DCHECK_NE(stats_collector_, nullptr);
1069     stats->Done(immutable_state_.params().device->name());
1070   }
1071 
1072   if (TF_PREDICT_TRUE(s.ok())) {
1073     const size_t ready_size = ready->size();
1074     if (ready_size == 0) {
1075       return num_outstanding_ops_.fetch_sub(1) == 1;
1076     } else {
1077       // NOTE: Avoid touching the atomic counter if only one node becomes ready.
1078       if (ready_size > 1) {
1079         num_outstanding_ops_.fetch_add(ready_size - 1,
1080                                        std::memory_order_relaxed);
1081       }
1082 
1083       // Schedule the ready nodes in 'ready'.
1084       ScheduleReady(ready, inline_ready);
1085 
1086       return false;
1087     }
1088   } else {
1089     bool abort_run = false;
1090 
1091     // Some error happened. This thread of computation is done.
1092     {
1093       mutex_lock l(mu_);
1094       if (status_.ok()) {
1095         // If this is the first node to fail in this run, we are responsible for
1096         // aborting all other execution in the step.
1097         abort_run = true;
1098 
1099         // If execution has been cancelled, mark cancelled or aborted errors as
1100         // being derived. Note that the original node that fails might also
1101         // trigger cancellation, and here we make sure the original error is
1102         // exposed to users and not buried as a derived error.
1103         if (cancellation_manager_ && cancellation_manager_->IsCancelled() &&
1104             (errors::IsCancelled(s) || errors::IsAborted(s))) {
1105           status_ = StatusGroup::MakeDerived(s);
1106         } else {
1107           status_ = s;
1108         }
1109       }
1110     }
1111 
1112     if (abort_run) {
1113       TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
1114       if (cancellation_manager_) {
1115         // Only log when the abort happens during the actual run time.
1116         // Use VLOG instead of LOG(warning) because error status is expected
1117         // when the executor is run under the grappler optimization phase or
1118         // when iterating through a tf.data input pipeline.
1119         VLOG(1) << "[" << immutable_state_.params().device->name()
1120                 << "] Executor start aborting: " << s;
1121       }
1122 
1123       if (rendezvous_) {
1124         rendezvous_->StartAbort(s);
1125       }
1126       if (cancellation_manager_) {
1127         cancellation_manager_->StartCancel();
1128       } else if (collective_executor_) {
1129         // If there's cancellation_manager_, collective ops aborts
1130         // collective_executor_ upon cancellation; otherwise we need to abort
1131         // here.
1132         collective_executor_->StartAbort(s);
1133       }
1134     }
1135 
1136     return num_outstanding_ops_.fetch_sub(1) == 1;
1137   }
1138 }
1139 
1140 template <class PropagatorStateType>
ScheduleReady(TaggedNodeSeq * ready,TaggedNodeReadyQueue * inline_ready)1141 void ExecutorState<PropagatorStateType>::ScheduleReady(
1142     TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready) {
1143   DCHECK(!ready->empty());
1144 
1145   int64 scheduled_nsec = 0;
1146   if (stats_collector_) {
1147     scheduled_nsec = nodestats::NowInNsec();
1148   }
1149 
1150   if (run_all_kernels_inline_) {
1151     if (inline_ready == nullptr) {
1152       // Schedule all ready kernels from a single closure. This ensure that,
1153       // regardless of the `runner_` implementation, all kernels will run
1154       // sequentially on the same thread, and thread wakeup overhead and
1155       // executor mutex contention will be minimized.
1156       RunTask([this, ready = std::move(*ready), scheduled_nsec]() {
1157         for (auto& tagged_node : ready) {
1158           Process(tagged_node, scheduled_nsec);
1159         }
1160       });
1161     } else {
1162       for (auto& tagged_node : *ready) {
1163         inline_ready->push_back(tagged_node);
1164       }
1165     }
1166   } else {
1167     const TaggedNode* curr_expensive_node = nullptr;
1168     if (inline_ready == nullptr) {
1169       // Schedule to run all the ready ops in thread pool.
1170       for (auto& tagged_node : *ready) {
1171         RunTask([=]() { Process(tagged_node, scheduled_nsec); });
1172       }
1173     } else {
1174       for (auto& tagged_node : *ready) {
1175         const NodeItem& item = *tagged_node.node_item;
1176         if (tagged_node.get_is_dead() || !kernel_stats_->IsExpensive(item)) {
1177           // Inline this inexpensive node.
1178           inline_ready->push_back(tagged_node);
1179         } else {
1180           if (curr_expensive_node) {
1181             // Dispatch to another thread since there is plenty of work to
1182             // do for this thread.
1183             RunTask(std::bind(&ExecutorState::Process, this,
1184                               *curr_expensive_node, scheduled_nsec));
1185           }
1186           curr_expensive_node = &tagged_node;
1187         }
1188       }
1189     }
1190     if (curr_expensive_node) {
1191       if (inline_ready->empty()) {
1192         inline_ready->push_back(*curr_expensive_node);
1193       } else {
1194         // There are inline nodes to run already. We dispatch this expensive
1195         // node to other thread.
1196         RunTask(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
1197                           scheduled_nsec));
1198       }
1199     }
1200   }
1201   ready->clear();
1202 }
1203 
1204 template <class PropagatorStateType>
ScheduleFinish()1205 void ExecutorState<PropagatorStateType>::ScheduleFinish() {
1206   // Checks condition to decide if needs to invoke Finish(). If there are
1207   // in-flight deffered ops, wait for `num_deferred_ops_` reaches 0 to invoke
1208   // Finish(). Otherwise, invoke Finish() directly.
1209   // Note that it is critical that the ScheduleFinish / Finish codepath does not
1210   // block, otherwise we might deadlock.  See b/124523000 for details.
1211   {
1212     mutex_lock lock(num_deferred_ops_mu_);
1213     if (num_deferred_ops_ > 0) {
1214       finish_when_deferred_ops_done_ = true;
1215       return;
1216     }
1217   }
1218   // Finish is always called exactly once per ExecutorState, either here if
1219   // there aren't any deferred ops, or in the dec_num_deferred_ops_function if
1220   // there are deferred ops.
1221   Finish();
1222 }
1223 
1224 template <class PropagatorStateType>
Finish()1225 void ExecutorState<PropagatorStateType>::Finish() {
1226   mu_.lock();
1227   auto status = status_;
1228   auto done_cb = std::move(done_cb_);
1229   auto runner = std::move(runner_);
1230   mu_.unlock();
1231   int64 step_id = step_id_;
1232   CHECK(done_cb != nullptr);
1233   Device* device = immutable_state_.params().device;
1234 
1235   if (vlog_ && !status.ok() && VLOG_IS_ON(1)) {
1236     // Logs verbose information about the current state of active and pending
1237     // nodes in the propagator.
1238     propagator_.DumpState();
1239   }
1240 
1241   // There are several potential race conditions below. To name a few:
1242   // 1. Even if the device's status is OK at the precise moment when
1243   // num_deferred_ops_ reaches 0, it could go bad before device->RefreshStatus()
1244   // is called below, caused by work enqueued onto the same device by other
1245   // concurrent ExecutorState objects.
1246   // 2. Some implementations of Device::RefreshStatus, such as
1247   // XlaDevice::RefreshStatus, may be inherently racy because it releases the
1248   // device mutex after a stream pointer is acquired and before the stream is
1249   // queried for status.
1250   // 3. It's the same for some implementations of Device::Sync, such as
1251   // XlaDevice::Sync.
1252   //
1253   // However, these race conditions are acceptable because a stream (and
1254   // therefore an XlaDevice) can only go from OK to not-OK, never the opposite,
1255   // which means we will at worst report errors when there isn't any, never the
1256   // opposite.
1257 
1258   // An early exit for devices don't allow sync on completion. Ops that run on
1259   // these devices should have used num_deferred_ops correctly to ensure the
1260   // device has finished all relevant work at this point.
1261   if (!device->AllowsSyncOnCompletion()) {
1262     status.Update(device->RefreshStatus());
1263     if (!status.ok()) {
1264       // In device async execution mode, it's possible for device execution to
1265       // lag behind ExecutorState scheduling so much that this is the first
1266       // place a device execution error surfaces.
1267       // If so, all ExecutorState::NodeDone calls have already happened with OK
1268       // status. This is the last defense where StartCancel must be called to
1269       // abort all computation still running on any device.
1270       // TODO(b/124523000): Always call Finish in a separate thread, so even if
1271       // StartCancel blocks the current thread's execution, we won't encounter
1272       // deadlocks caused by inter-op thread exhaustion.
1273       if (rendezvous_) {
1274         rendezvous_->StartAbort(status);
1275       }
1276       if (cancellation_manager_) {
1277         cancellation_manager_->StartCancel();
1278       } else if (collective_executor_) {
1279         // If there's cancellation_manager_, collective ops aborts
1280         // collective_executor_ upon cancellation; otherwise we need to abort
1281         // here.
1282         collective_executor_->StartAbort(status);
1283       }
1284     }
1285     delete this;
1286     runner([step_id, status, done_cb = std::move(done_cb)]() {
1287       profiler::TraceMeConsumer activity(
1288           // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
1289           // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
1290           [&] {
1291             return profiler::TraceMeEncode("ExecutorDoneCallback",
1292                                            {{"id", step_id}});
1293           },
1294           profiler::ContextType::kTfExecutor, step_id,
1295           profiler::TraceMeLevel::kInfo);
1296       done_cb(status);
1297     });
1298     return;
1299   }
1300 
1301   if (sync_on_finish_ && status.ok()) {
1302     // Block until the device has finished all queued operations. For
1303     // devices like GPUs that continue to execute Ops after their Compute
1304     // methods have completed, this ensures that control is not returned to
1305     // the user until the step (and its side-effects) has actually completed.
1306     device->Sync([this, step_id, runner = std::move(runner),
1307                   done_cb = std::move(done_cb)](const Status& status) mutable {
1308       delete this;
1309       runner([step_id, status, done_cb = std::move(done_cb)]() {
1310         profiler::TraceMeConsumer activity(
1311             // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
1312             // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
1313             [&] {
1314               return profiler::TraceMeEncode("ExecutorDoneCallback",
1315                                              {{"id", step_id}});
1316             },
1317             profiler::ContextType::kTfExecutor, step_id,
1318             profiler::TraceMeLevel::kInfo);
1319         done_cb(status);
1320       });
1321     });
1322   } else {
1323     delete this;
1324     runner([step_id, status, done_cb = std::move(done_cb)]() {
1325       profiler::TraceMeConsumer activity(
1326           // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
1327           // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
1328           [&] {
1329             return profiler::TraceMeEncode("ExecutorDoneCallback",
1330                                            {{"id", step_id}});
1331           },
1332           profiler::ContextType::kTfExecutor, step_id,
1333           profiler::TraceMeLevel::kInfo);
1334       done_cb(status);
1335     });
1336   }
1337 }
1338 
RunAsync(const Args & args,DoneCallback done)1339 void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
1340   if (immutable_state_.requires_control_flow_support()) {
1341     (new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_))
1342         ->RunAsync(std::move(done));
1343   } else {
1344     (new ExecutorState<SimplePropagatorState>(args, immutable_state_,
1345                                               &kernel_stats_))
1346         ->RunAsync(std::move(done));
1347   }
1348 }
1349 
1350 }  // namespace
1351 
NewLocalExecutor(const LocalExecutorParams & params,const Graph & graph,Executor ** executor)1352 Status NewLocalExecutor(const LocalExecutorParams& params, const Graph& graph,
1353                         Executor** executor) {
1354   ExecutorImpl* impl = new ExecutorImpl(params);
1355   const Status s = impl->Initialize(graph);
1356   if (s.ok()) {
1357     *executor = impl;
1358   } else {
1359     delete impl;
1360   }
1361   return s;
1362 }
1363 
CreateNonCachedKernel(Device * device,FunctionLibraryRuntime * flib,const std::shared_ptr<const NodeProperties> & props,int graph_def_version,OpKernel ** kernel)1364 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
1365                              const std::shared_ptr<const NodeProperties>& props,
1366                              int graph_def_version, OpKernel** kernel) {
1367   const auto device_type = DeviceType(device->attributes().device_type());
1368   auto allocator = device->GetAllocator(AllocatorAttributes());
1369   return CreateOpKernel(device_type, device, allocator, flib,
1370                         device->resource_manager(), props, graph_def_version,
1371                         kernel);
1372 }
1373 
DeleteNonCachedKernel(OpKernel * kernel)1374 void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }
1375 
1376 namespace {
1377 
1378 class DefaultExecutorRegistrar {
1379  public:
DefaultExecutorRegistrar()1380   DefaultExecutorRegistrar() {
1381     Factory* factory = new Factory;
1382     ExecutorFactory::Register("", factory);
1383     ExecutorFactory::Register("DEFAULT", factory);
1384   }
1385 
1386  private:
1387   class Factory : public ExecutorFactory {
NewExecutor(const LocalExecutorParams & params,const Graph & graph,std::unique_ptr<Executor> * out_executor)1388     Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
1389                        std::unique_ptr<Executor>* out_executor) override {
1390       Executor* ret = nullptr;
1391       TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret));
1392       out_executor->reset(ret);
1393       return Status::OK();
1394     }
1395   };
1396 };
1397 static DefaultExecutorRegistrar registrar;
1398 
1399 }  // namespace
1400 
1401 }  // namespace tensorflow
1402