1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/executor.h"
17 
18 #include <atomic>
19 #include <deque>
20 #include <memory>
21 #include <string>
22 #include <unordered_map>
23 #include <vector>
24 
25 #include "tensorflow/core/common_runtime/costmodel_manager.h"
26 #include "tensorflow/core/common_runtime/executor_factory.h"
27 #include "tensorflow/core/common_runtime/pending_counts.h"
28 #include "tensorflow/core/common_runtime/step_stats_collector.h"
29 #include "tensorflow/core/framework/allocation_description.pb.h"
30 #include "tensorflow/core/framework/allocator.h"
31 #include "tensorflow/core/framework/cancellation.h"
32 #include "tensorflow/core/framework/collective.h"
33 #include "tensorflow/core/framework/control_flow.h"
34 #include "tensorflow/core/framework/device_attributes.pb.h"
35 #include "tensorflow/core/framework/graph.pb.h"
36 #include "tensorflow/core/framework/log_memory.h"
37 #include "tensorflow/core/framework/node_def_util.h"
38 #include "tensorflow/core/framework/op_kernel.h"
39 #include "tensorflow/core/framework/op_segment.h"
40 #include "tensorflow/core/framework/step_stats.pb.h"
41 #include "tensorflow/core/framework/tensor.h"
42 #include "tensorflow/core/framework/tensor_reference.h"
43 #include "tensorflow/core/framework/types.h"
44 #include "tensorflow/core/framework/types.pb.h"
45 #include "tensorflow/core/graph/edgeset.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/core/notification.h"
48 #include "tensorflow/core/lib/core/stringpiece.h"
49 #include "tensorflow/core/lib/core/threadpool.h"
50 #include "tensorflow/core/lib/gtl/flatmap.h"
51 #include "tensorflow/core/lib/gtl/flatset.h"
52 #include "tensorflow/core/lib/gtl/inlined_vector.h"
53 #include "tensorflow/core/lib/gtl/manual_constructor.h"
54 #include "tensorflow/core/lib/gtl/stl_util.h"
55 #include "tensorflow/core/lib/hash/hash.h"
56 #include "tensorflow/core/lib/strings/str_util.h"
57 #include "tensorflow/core/lib/strings/stringprintf.h"
58 #include "tensorflow/core/platform/context.h"
59 #include "tensorflow/core/platform/env.h"
60 #include "tensorflow/core/platform/logging.h"
61 #include "tensorflow/core/platform/macros.h"
62 #include "tensorflow/core/platform/mutex.h"
63 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
64 #include "tensorflow/core/platform/thread_annotations.h"
65 #include "tensorflow/core/platform/tracing.h"
66 #include "tensorflow/core/platform/types.h"
67 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
68 
69 namespace tensorflow {
70 namespace {
71 
72 // 1-D, 0 element tensor.
73 static const Tensor* const kEmptyTensor = new Tensor;
74 
IsInitializationOp(const Node * node)75 bool IsInitializationOp(const Node* node) {
76   return node->op_def().allows_uninitialized_input();
77 }
78 
79 // Helper routines for collecting step stats.
80 namespace nodestats {
NowInNsec()81 inline int64 NowInNsec() { return Env::Default()->NowNanos(); }
82 
SetScheduled(NodeExecStatsInterface * stats,int64 micros)83 void SetScheduled(NodeExecStatsInterface* stats, int64 micros) {
84   if (!stats) return;
85   stats->SetScheduled(micros * EnvTime::kMicrosToNanos);
86 }
87 
SetAllStart(NodeExecStatsInterface * stats)88 void SetAllStart(NodeExecStatsInterface* stats) {
89   if (!stats) return;
90   stats->RecordExecutorStarted();
91 }
92 
SetOpStart(NodeExecStatsInterface * stats)93 void SetOpStart(NodeExecStatsInterface* stats) {
94   if (!stats) return;
95   stats->RecordComputeStarted();
96 }
97 
SetOpEnd(NodeExecStatsInterface * stats)98 void SetOpEnd(NodeExecStatsInterface* stats) {
99   if (!stats) return;
100   stats->RecordComputeEnded();
101 }
102 
SetAllEnd(NodeExecStatsInterface * stats)103 void SetAllEnd(NodeExecStatsInterface* stats) {
104   if (!stats) return;
105   stats->RecordExecutorEnded();
106 }
107 
SetOutput(NodeExecStatsInterface * stats,int slot,const Tensor * v)108 void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) {
109   if (!stats) return;
110   stats->SetOutput(slot, v);
111 }
112 
SetMemory(NodeExecStatsInterface * stats,OpKernelContext * ctx)113 void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
114   if (!stats) return;
115   stats->SetMemory(ctx);
116 }
117 
SetReferencedTensors(NodeExecStatsInterface * stats,const TensorReferenceVector & tensors)118 void SetReferencedTensors(NodeExecStatsInterface* stats,
119                           const TensorReferenceVector& tensors) {
120   if (!stats) return;
121   stats->SetReferencedTensors(tensors);
122 }
123 
124 }  // namespace nodestats
125 
class ExecutorImpl;
class GraphView;

// Compact description of one out-edge of a node.  An array of these lives in
// the variable-length tail of each NodeItem (filled by
// GraphView::InitializeNode).
struct EdgeInfo {
  // Id of the destination node of the edge.
  int dst_id;
  // Output index on the source node.  Stored in a 31-bit signed bit-field;
  // InitializeNode CHECKs that the value is <= 0x3FFFFFFF.  Negative values
  // are never marked is_last (presumably control edges — confirm).
  int output_slot : 31;
  // true if this is the last info for output_slot in the EdgeInfo list.
  bool is_last : 1;
  // Input index on the destination node fed by this edge.
  int input_slot;
};
136 
137 // Time the execution of kernels (in CPU cycles).  Used to dynamically identify
138 // inexpensive kernels which can be dispatched inline.
139 struct KernelTimer {
140   uint64 start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle();
141 
ElapsedCyclestensorflow::__anon6f8fc96b0111::KernelTimer142   uint64 ElapsedCycles() {
143     return profile_utils::CpuUtils::GetCurrentClockCycle() - start_cycles;
144   }
145 };
146 
147 struct NodeItem {
NodeItemtensorflow::__anon6f8fc96b0111::NodeItem148   NodeItem() {}
149 
150   // A graph node.
151   const Node* node = nullptr;
152 
153   // The kernel for this node.
154   OpKernel* kernel = nullptr;
155 
156   bool kernel_is_async : 1;      // True iff kernel->AsAsync() != nullptr
157   bool is_merge : 1;             // True iff IsMerge(node)
158   bool is_enter : 1;             // True iff IsEnter(node)
159   bool is_constant_enter : 1;    // True iff IsEnter(node) and
160                                  // node->GetAttr("is_constant") == true.
161   bool is_exit : 1;              // True iff IsExit(node)
162   bool is_control_trigger : 1;   // True iff IsControlTrigger(node)
163   bool is_sink : 1;              // True iff IsSink(node)
164   // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node)
165   bool is_enter_exit_or_next_iter : 1;
166 
167   // Cached values of node->num_inputs() and node->num_outputs(), to
168   // avoid levels of indirection.
169   int num_inputs;
170   int num_outputs;
171 
172   // ExecutorImpl::tensors_[input_start] is the 1st positional input
173   // for this node.
174   int input_start = 0;
175 
176   // Number of output edges.
177   size_t num_output_edges;
178 
179   PendingCounts::Handle pending_id;
180 
output_edge_listtensorflow::__anon6f8fc96b0111::NodeItem181   const EdgeInfo* output_edge_list() const { return output_edge_base(); }
182 
183   // ith output edge.
output_edgetensorflow::__anon6f8fc96b0111::NodeItem184   const EdgeInfo& output_edge(int i) const {
185     DCHECK_GE(i, 0);
186     DCHECK_LT(i, num_output_edges);
187     return output_edge_base()[i];
188   }
189 
input_typetensorflow::__anon6f8fc96b0111::NodeItem190   DataType input_type(int i) const {
191     DCHECK_LT(i, num_inputs);
192     return static_cast<DataType>(input_type_base()[i]);
193   }
output_typetensorflow::__anon6f8fc96b0111::NodeItem194   DataType output_type(int i) const {
195     DCHECK_LT(i, num_outputs);
196     return static_cast<DataType>(output_type_base()[i]);
197   }
198 
199   // Return array of per-output allocator attributes.
output_attrstensorflow::__anon6f8fc96b0111::NodeItem200   const AllocatorAttributes* output_attrs() const { return output_attr_base(); }
201 
202   // Return array of expected input index from which each output should
203   // be forwarded:
204   // kNeverForward (-2) for DO NOT FORWARD (must allocate).
205   // kNoReservation (-1) for no expected forwarding.
206   // 0... for forward from that input.
forward_fromtensorflow::__anon6f8fc96b0111::NodeItem207   const int* forward_from() const { return forward_from_base(); }
208 
209  private:
210   friend class GraphView;
211 
212   // Variable length section starts immediately after *this
213   // (uint8 is enough for DataType).
214   //   EdgeInfo            out_edges[num_out_edges];
215   //   AllocatorAttributes output_attr[num_outputs];
216   //   int                 forward_from[num_outputs];
217   //   uint8               input_type[num_inputs];
218   //   uint8               output_type[num_outputs];
219 
220   // Return pointer to variable length section.
vartensorflow::__anon6f8fc96b0111::NodeItem221   char* var() const {
222     return const_cast<char*>(reinterpret_cast<const char*>(this) +
223                              sizeof(NodeItem));
224   }
225 
output_edge_basetensorflow::__anon6f8fc96b0111::NodeItem226   EdgeInfo* output_edge_base() const {
227     return reinterpret_cast<EdgeInfo*>(var());
228   }
output_attr_basetensorflow::__anon6f8fc96b0111::NodeItem229   AllocatorAttributes* output_attr_base() const {
230     return reinterpret_cast<AllocatorAttributes*>(var() + sizeof(EdgeInfo) *
231                                                               num_output_edges);
232   }
forward_from_basetensorflow::__anon6f8fc96b0111::NodeItem233   int* forward_from_base() const {
234     return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges +
235                                   sizeof(AllocatorAttributes) * num_outputs);
236   }
input_type_basetensorflow::__anon6f8fc96b0111::NodeItem237   uint8* input_type_base() const {
238     return reinterpret_cast<uint8*>(
239         var() + sizeof(EdgeInfo) * num_output_edges +
240         sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs);
241   }
output_type_basetensorflow::__anon6f8fc96b0111::NodeItem242   uint8* output_type_base() const {
243     return reinterpret_cast<uint8*>(
244         var() + sizeof(EdgeInfo) * num_output_edges +
245         sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs +
246         sizeof(uint8) * num_inputs);
247   }
248 
249   TF_DISALLOW_COPY_AND_ASSIGN(NodeItem);
250 };
251 
252 typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
253 typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec;
254 typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
255 
256 // Immutable view of a Graph organized for efficient execution.
257 class GraphView {
258  public:
GraphView()259   GraphView() : space_(nullptr) {}
260   ~GraphView();
261 
262   void Initialize(const Graph* g);
263   Status SetAllocAttrs(const Graph* g, const Device* device);
264   void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes);
265 
node(size_t id) const266   NodeItem* node(size_t id) const {
267     DCHECK_GE(id, 0);
268     DCHECK_LT(id, num_nodes_);
269     uint32 offset = node_offsets_[id];
270     return ((offset == kuint32max)
271                 ? nullptr
272                 : reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
273   }
274 
275  private:
276   char* InitializeNode(char* ptr, const Node* n);
277   size_t NodeItemBytes(const Node* n);
278 
279   int32 num_nodes_ = 0;
280   uint32* node_offsets_ = nullptr;  // array of size "graph_.num_node_ids()"
281   // node_offsets_[id] holds the byte offset for node w/ "id" in space_
282 
283   char* space_;  // NodeItem objects are allocated here
284 
285   TF_DISALLOW_COPY_AND_ASSIGN(GraphView);
286 };
287 
288 class ExecutorImpl : public Executor {
289  public:
ExecutorImpl(const LocalExecutorParams & p,std::unique_ptr<const Graph> g)290   ExecutorImpl(const LocalExecutorParams& p, std::unique_ptr<const Graph> g)
291       : params_(p), graph_(std::move(g)), gview_() {
292     CHECK(p.create_kernel != nullptr);
293     CHECK(p.delete_kernel != nullptr);
294   }
295 
~ExecutorImpl()296   ~ExecutorImpl() override {
297     for (int i = 0; i < graph_->num_node_ids(); i++) {
298       NodeItem* item = gview_.node(i);
299       if (item != nullptr) {
300         params_.delete_kernel(item->kernel);
301       }
302     }
303     for (auto fiter : frame_info_) {
304       delete fiter.second;
305     }
306   }
307 
308   Status Initialize();
309 
310   // Process all Nodes in the current graph, attempting to infer the
311   // memory allocation attributes to be used wherever they may allocate
312   // a tensor buffer.
313   Status SetAllocAttrs();
314 
315   void RunAsync(const Args& args, DoneCallback done) override;
316 
317  private:
318   friend class ExecutorState;
319 
320   struct ControlFlowInfo {
321     gtl::FlatSet<string> unique_frame_names;
322     std::vector<string> frame_names;
323   };
324 
325   struct FrameInfo {
FrameInfotensorflow::__anon6f8fc96b0111::ExecutorImpl::FrameInfo326     FrameInfo()
327         : input_count(0),
328           total_inputs(0),
329           pending_counts(nullptr),
330           nodes(nullptr) {}
331 
332     // The total number of inputs to a frame.
333     int input_count;
334 
335     // The total number of input tensors of a frame.
336     // == sum(nodes[*].num_inputs()) where nodes are the nodes in the frame.
337     int total_inputs;
338 
339     // Used to determine the next place to allocate space in the
340     // pending_counts data structure we'll eventually construct
341     PendingCounts::Layout pending_counts_layout;
342 
343     // Each frame has its own PendingCounts only for the nodes in the frame.
344     PendingCounts* pending_counts;  // Owned
345 
346     // The nodes in a frame. Used only for debugging.
347     std::vector<const Node*>* nodes;  // Owned
348 
~FrameInfotensorflow::__anon6f8fc96b0111::ExecutorImpl::FrameInfo349     ~FrameInfo() {
350       delete pending_counts;
351       delete nodes;
352     }
353   };
354 
355   static Status BuildControlFlowInfo(const Graph* graph,
356                                      ControlFlowInfo* cf_info);
357   void InitializePending(const Graph* graph, const ControlFlowInfo& cf_info);
358 
EnsureFrameInfo(const string & fname)359   FrameInfo* EnsureFrameInfo(const string& fname) {
360     auto slot = &frame_info_[fname];
361     if (*slot == nullptr) {
362       *slot = new FrameInfo;
363     }
364     return *slot;
365   }
366 
367   // Owned.
368   LocalExecutorParams params_;
369   std::unique_ptr<const Graph> graph_;
370   GraphView gview_;
371 
372   // A cached value of params_
373   bool device_record_tensor_accesses_ = false;
374 
375   // Root nodes (with no in edges) that should form the initial ready queue
376   std::vector<const Node*> root_nodes_;
377 
378   // Mapping from frame name to static information about the frame.
379   // TODO(yuanbyu): We could cache it along with the graph so to avoid
380   // the overhead of constructing it for each executor instance.
381   gtl::FlatMap<string, FrameInfo*> frame_info_;
382 
383   TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
384 };
385 
// Infer memory allocation attributes of a node n's output,
// based on its use node dst.  Note that dst might not be directly
// connected to n by a single edge, but might be a downstream
// consumer of n's output by reference.  *attr is updated with any
// necessary attributes.  Returns a non-OK status when node attributes
// needed for the inference are missing or malformed (e.g. a Recv node's
// "send_device").  Declared here because GraphView::SetAllocAttrs uses it
// before its definition.
Status InferAllocAttr(const Node* n, const Node* dst,
                      const DeviceNameUtils::ParsedName& local_dev_name,
                      AllocatorAttributes* attr);
394 
~GraphView()395 GraphView::~GraphView() {
396   static_assert(std::is_trivially_destructible<AllocatorAttributes>::value,
397                 "Update code if AllocatorAttributes gains a destructor");
398   static_assert(std::is_trivially_destructible<EdgeInfo>::value,
399                 "Update code if EdgeInfo gains a destructor");
400   for (int i = 0; i < num_nodes_; i++) {
401     NodeItem* n = node(i);
402     if (n != nullptr) {
403       n->NodeItem::~NodeItem();
404       // Memory for "n" itself is held in space_ & gets cleaned up below
405     }
406   }
407   delete[] node_offsets_;
408   delete[] space_;
409 }
410 
NodeItemBytes(const Node * n)411 size_t GraphView::NodeItemBytes(const Node* n) {
412   const size_t num_output_edges = n->out_edges().size();
413   const int num_inputs = n->num_inputs();
414   const int num_outputs = n->num_outputs();
415 
416   // Compute number of bytes needed for NodeItem and variable length data.
417   // We do not subtract sizeof(var) since num_inputs/num_outputs might
418   // both be zero.
419   const size_t raw_bytes =
420       sizeof(NodeItem)                             // Fixed
421       + num_output_edges * sizeof(EdgeInfo)        // output_edges[...]
422       + num_outputs * sizeof(AllocatorAttributes)  // output_attr[...]
423       + num_outputs * sizeof(int)                  // forward_from[num_outputs]
424       + num_inputs * sizeof(uint8)                 // input_type[num_inputs]
425       + num_outputs * sizeof(uint8);               // output_type[num_outputs]
426   static constexpr size_t kItemAlignment = sizeof(NodeItem*);
427   static_assert(kItemAlignment % alignof(NodeItem) == 0,
428                 "NodeItem must be aligned with kItemAlignment");
429   static_assert(kItemAlignment % alignof(EdgeInfo) == 0,
430                 "EdgeInfo must be aligned with kItemAlignment");
431   static_assert(kItemAlignment % alignof(AllocatorAttributes) == 0,
432                 "AllocatorAttributes must be aligned with kItemAlignment");
433   static_assert(sizeof(NodeItem) % alignof(EdgeInfo) == 0,
434                 "NodeItem must be aligned with EdgeInfo");
435   static_assert(sizeof(NodeItem) % alignof(AllocatorAttributes) == 0,
436                 "NodeItem must be aligned with AllocatorAttributes");
437   static_assert(sizeof(EdgeInfo) % alignof(AllocatorAttributes) == 0,
438                 "EdgeInfo must be aligned with AllocatorAttributes");
439   const size_t bytes =
440       ((raw_bytes + kItemAlignment - 1) / kItemAlignment) * kItemAlignment;
441   return bytes;
442 }
443 
InitializeNode(char * ptr,const Node * n)444 char* GraphView::InitializeNode(char* ptr, const Node* n) {
445   const int id = n->id();
446   CHECK(node_offsets_[id] == kuint32max);  // Initial value in constructor
447 
448   const size_t bytes = NodeItemBytes(n);
449   constexpr size_t kItemAlignment = sizeof(NodeItem*);
450   CHECK_EQ(reinterpret_cast<uintptr_t>(ptr) % kItemAlignment, 0);
451   NodeItem* item = reinterpret_cast<NodeItem*>(ptr);
452 
453   // We store a 32-bit offset relative to the beginning of space_, so that we
454   // only need an array of 32-bit values to map from node id to the NodeItem*,
455   // (versus 64 bits on most machines if we just stored an array of NodeItem*
456   // pointers). Casting to int64 is needed on 32bit CPU to avoid comparing
457   // values as "int" vs "size_t" in CHECK_LE.
458   CHECK_LE(static_cast<int64>(ptr - space_), kuint32max);
459   const uint32 offset = static_cast<uint32>(ptr - space_);
460   node_offsets_[id] = offset;
461   ptr += bytes;
462 
463   const size_t num_output_edges = n->out_edges().size();
464   const int num_inputs = n->num_inputs();
465   const int num_outputs = n->num_outputs();
466 
467   new (item) NodeItem();
468   item->num_inputs = num_inputs;
469   item->num_outputs = num_outputs;
470   item->num_output_edges = num_output_edges;
471 
472   // Fill output edges.
473   // Keep track of the last EdgeInfo in the EdgeInfo array that references
474   // a given output slot.  For all but the last, we need to do a copy of the
475   // Tensor when propagating results downstream in the graph, but for the
476   // last one, we can just do a move of the Tensor object to propagate it.
477   gtl::InlinedVector<EdgeInfo*, 4> last_indices(num_outputs, nullptr);
478   EdgeInfo* dst_edge = item->output_edge_base();
479   for (auto e : n->out_edges()) {
480     dst_edge->dst_id = e->dst()->id();
481     CHECK_LE(e->src_output(), 0x3FFFFFFF);  // Must fit in 31 bits
482     dst_edge->output_slot = e->src_output();
483     dst_edge->is_last = false;
484     const int output_slot = dst_edge->output_slot;
485     if (output_slot >= 0) {
486       last_indices[output_slot] = dst_edge;
487     }
488     dst_edge->input_slot = e->dst_input();
489     dst_edge++;
490   }
491   for (EdgeInfo* edge_info : last_indices) {
492     if (edge_info != nullptr) {
493       edge_info->is_last = true;
494     }
495   }
496 
497   AllocatorAttributes* output_attrs = item->output_attr_base();
498   for (int i = 0; i < num_outputs; i++) {
499     new (&output_attrs[i]) AllocatorAttributes();
500   }
501 
502   DCHECK_LT(DataType_MAX, 255);  // Must fit in uint8
503   uint8* input_types = item->input_type_base();
504   for (int i = 0; i < num_inputs; i++) {
505     input_types[i] = static_cast<uint8>(n->input_type(i));
506     DCHECK_EQ(item->input_type(i), n->input_type(i));
507   }
508 
509   // Check ScopedAllocatorAttrs and forward_from.  Also assign output_types.
510   {
511     std::vector<int> forward_input;
512     Status fwd_status =
513         GetNodeAttr(n->attrs(), "_forward_input", &forward_input);
514     std::vector<int> scoped_allocator_attrs;
515     Status sa_status =
516         GetNodeAttr(n->attrs(), "_scoped_allocator", &scoped_allocator_attrs);
517 
518     int* forward_from = item->forward_from_base();
519     uint8* output_types = item->output_type_base();
520     for (int i = 0; i < num_outputs; ++i) {
521       output_types[i] = static_cast<uint8>(n->output_type(i));
522       DCHECK_EQ(item->output_type(i), n->output_type(i));
523 
524       forward_from[i] = OpKernelContext::Params::kNoReservation;
525       if (sa_status.ok()) {
526         for (int j = 0; j < scoped_allocator_attrs.size(); j += 2) {
527           if (scoped_allocator_attrs[j] == i) {
528             // This output slot must be explicitly allocated from a
529             // ScopedAllocator.
530             forward_from[i] = OpKernelContext::Params::kNeverForward;
531             DCHECK_EQ(output_attrs[i].scope_id, 0);
532             output_attrs[i].scope_id = scoped_allocator_attrs[j + 1];
533           }
534         }
535       }
536       if (fwd_status.ok() &&
537           forward_from[i] == OpKernelContext::Params::kNoReservation) {
538         DCHECK_EQ(forward_input.size() % 2, 0);
539         for (int j = 0; j < forward_input.size(); j += 2) {
540           if (forward_input[j + 1] == i) {
541             DCHECK_EQ(forward_from[i], OpKernelContext::Params::kNoReservation);
542             forward_from[i] = forward_input[j];
543             break;
544           }
545         }
546       }
547     }
548   }
549 
550   return ptr;
551 }
552 
Initialize(const Graph * g)553 void GraphView::Initialize(const Graph* g) {
554   CHECK(node_offsets_ == nullptr);
555   const int num_nodes = g->num_node_ids();
556   num_nodes_ = num_nodes;
557   size_t total_bytes = 0;
558   for (const Node* n : g->nodes()) {
559     total_bytes += NodeItemBytes(n);
560   }
561 
562   node_offsets_ = new uint32[num_nodes];
563   for (int i = 0; i < num_nodes; i++) {
564     node_offsets_[i] = kuint32max;
565   }
566 
567   space_ = new char[total_bytes];  // NodeItem objects are allocated here
568   char* ptr = space_;
569   for (const Node* n : g->nodes()) {
570     ptr = InitializeNode(ptr, n);
571   }
572   CHECK_EQ(ptr, space_ + total_bytes);
573 }
574 
GetMaxPendingCounts(const Node * n,size_t * max_pending,size_t * max_dead_count)575 void GetMaxPendingCounts(const Node* n, size_t* max_pending,
576                          size_t* max_dead_count) {
577   const size_t num_in_edges = n->in_edges().size();
578   size_t initial_count;
579   if (IsMerge(n)) {
580     // merge waits all control inputs so we initialize the pending
581     // count to be the number of control edges.
582     int32 num_control_edges = 0;
583     for (const Edge* edge : n->in_edges()) {
584       if (edge->IsControlEdge()) {
585         num_control_edges++;
586       }
587     }
588     // Use bit 0 to indicate if we are waiting for a ready live data input.
589     initial_count = 1 + (num_control_edges << 1);
590   } else {
591     initial_count = num_in_edges;
592   }
593 
594   *max_pending = initial_count;
595   *max_dead_count = num_in_edges;
596 }
597 
Initialize()598 Status ExecutorImpl::Initialize() {
599   gview_.Initialize(graph_.get());
600 
601   // Build the information about frames in this subgraph.
602   ControlFlowInfo cf_info;
603   TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &cf_info));
604 
605   // Cache this value so we make this virtual function call once, rather
606   // that O(# steps * # nodes per step) times.
607   device_record_tensor_accesses_ =
608       params_.device->RequiresRecordingAccessedTensors();
609 
610   for (auto& it : cf_info.unique_frame_names) {
611     EnsureFrameInfo(it)->nodes = new std::vector<const Node*>;
612   }
613 
614   // Preprocess every node in the graph to create an instance of op
615   // kernel for each node.
616   for (const Node* n : graph_->nodes()) {
617     const int id = n->id();
618     const string& frame_name = cf_info.frame_names[id];
619     FrameInfo* frame_info = EnsureFrameInfo(frame_name);
620 
621     // See if this node is a root node, and if so, add to root_nodes_.
622     if (n->in_edges().empty()) {
623       root_nodes_.push_back(n);
624     }
625 
626     NodeItem* item = gview_.node(id);
627     item->node = n;
628 
629     item->input_start = frame_info->total_inputs;
630     frame_info->total_inputs += n->num_inputs();
631 
632     Status s = params_.create_kernel(n->def(), &item->kernel);
633     if (!s.ok()) {
634       item->kernel = nullptr;
635       s = AttachDef(s, *n);
636       LOG(ERROR) << "Executor failed to create kernel. " << s;
637       return s;
638     }
639     CHECK(item->kernel);
640     item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
641     item->is_merge = IsMerge(n);
642     item->is_enter = IsEnter(n);
643     if (item->is_enter) {
644       bool is_constant_enter;
645       TF_RETURN_IF_ERROR(
646           GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
647       item->is_constant_enter = is_constant_enter;
648     } else {
649       item->is_constant_enter = false;
650     }
651     item->is_exit = IsExit(n);
652     item->is_control_trigger = IsControlTrigger(n);
653     item->is_sink = IsSink(n);
654     item->is_enter_exit_or_next_iter =
655         (IsEnter(n) || IsExit(n) || IsNextIteration(n));
656 
657     // Compute the maximum values we'll store for this node in the
658     // pending counts data structure, and allocate a handle in
659     // that frame's pending counts data structure that has enough
660     // space to store these maximal count values.
661     size_t max_pending, max_dead;
662     GetMaxPendingCounts(n, &max_pending, &max_dead);
663     item->pending_id =
664         frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead);
665 
666     // Initialize static information about the frames in the graph.
667     frame_info->nodes->push_back(n);
668     if (IsEnter(n)) {
669       string enter_name;
670       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name));
671       EnsureFrameInfo(enter_name)->input_count++;
672     }
673   }
674 
675   // Initialize PendingCounts only after item->pending_id is initialized for
676   // all nodes.
677   InitializePending(graph_.get(), cf_info);
678 
679   return gview_.SetAllocAttrs(graph_.get(), params_.device);
680 }
681 
682 // If a Node has been marked to use a ScopedAllocator x for output i, then
683 // sc_attr will contain the subsequence (i, x) at an even offset.  This function
684 // extracts and transfers that ScopedAllocator id to alloc_attr.  For now, we
685 // only allow one ScopedAllocator use per Node.
ExtractScopedAllocatorAttr(const std::vector<int> & sc_attr,int output_index,AllocatorAttributes * alloc_attr)686 bool ExtractScopedAllocatorAttr(const std::vector<int>& sc_attr,
687                                 int output_index,
688                                 AllocatorAttributes* alloc_attr) {
689   DCHECK_LE(2, sc_attr.size());
690   for (int i = 0; i < sc_attr.size(); i += 2) {
691     if (sc_attr[i] == output_index) {
692       CHECK_EQ(alloc_attr->scope_id, 0);
693       alloc_attr->scope_id = sc_attr[i + 1];
694       return true;
695     }
696   }
697   return false;
698 }
699 
SetScopedAllocatorAttrs(const std::vector<const Node * > & sa_nodes)700 void GraphView::SetScopedAllocatorAttrs(
701     const std::vector<const Node*>& sa_nodes) {
702   for (const Node* sa : sa_nodes) {
703     NodeItem* sa_item = node(sa->id());
704     AllocatorAttributes* sa_attrs = sa_item->output_attr_base();
705     // Control edges out of the ScopedAllocator should be use instances, but may
706     // include a few other nodes.
707     for (const auto& e : sa->out_edges()) {
708       if (!e->IsControlEdge()) {
709         continue;
710       }
711       Node* use_node = e->dst();
712       NodeItem* item = node(use_node->id());
713       AllocatorAttributes* use_attrs = item->output_attr_base();
714       std::vector<int> scoped_allocator_attrs;
715       Status s = GetNodeAttr(use_node->attrs(), "_scoped_allocator",
716                              &scoped_allocator_attrs);
717       if (!s.ok()) {
718         VLOG(2) << "Failed to find expected ScopedAllocator attr on "
719                 << use_node->name();
720         continue;
721       }
722       // There can be more than one output using ScopedAllocation, but this
723       // analysis assumes they use the same ScopedAllocator.
724       for (const auto& e : use_node->out_edges()) {
725         if (!e->IsControlEdge()) {
726           AllocatorAttributes attr;
727           if (ExtractScopedAllocatorAttr(scoped_allocator_attrs,
728                                          e->src_output(), &attr)) {
729             // Set the scope_id on this use instance node.
730             (use_attrs + e->src_output())->Merge(attr);
731             // Propagate the other attributes of this node back to the SA node.
732             attr = *(use_attrs + e->src_output());
733             attr.scope_id = 0;
734             sa_attrs->Merge(attr);
735           }
736         }
737       }
738     }
739   }
740 }
741 
SetAllocAttrs(const Graph * g,const Device * device)742 Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) {
743   Status s;
744   DeviceNameUtils::ParsedName local_dev_name = device->parsed_name();
745 
746   std::vector<const Node*> scoped_allocator_instances;
747   for (const Node* n : g->nodes()) {
748     NodeItem* item = node(n->id());
749     AllocatorAttributes* attrs = item->output_attr_base();
750     if (IsScopedAllocator(n)) {
751       scoped_allocator_instances.push_back(n);
752     }
753 
754     // Examine the out edges of each node looking for special use
755     // cases that may affect memory allocation attributes.
756     for (const auto& e : n->out_edges()) {
757       if (!e->IsControlEdge()) {
758         AllocatorAttributes attr;
759         s = InferAllocAttr(n, e->dst(), local_dev_name, &attr);
760         if (!s.ok()) return s;
761         if (attr.value != 0 || attr.scope_id != 0) {
762           attrs[e->src_output()].Merge(attr);
763         }
764       }
765     }
766 
767     for (int out = 0; out < n->num_outputs(); out++) {
768       const OpKernel* op_kernel = item->kernel;
769       DCHECK_LT(out, op_kernel->output_memory_types().size());
770       bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY;
771       if (on_host) {
772         AllocatorAttributes h;
773         h.set_on_host(on_host);
774         attrs[out].Merge(h);
775       }
776     }
777   }
778   SetScopedAllocatorAttrs(scoped_allocator_instances);
779   return s;
780 }
781 
InferAllocAttr(const Node * n,const Node * dst,const DeviceNameUtils::ParsedName & local_dev_name,AllocatorAttributes * attr)782 Status InferAllocAttr(const Node* n, const Node* dst,
783                       const DeviceNameUtils::ParsedName& local_dev_name,
784                       AllocatorAttributes* attr) {
785   Status s;
786   // Note that it's possible for *n to be a Recv and *dst to be a Send,
787   // so these two cases are not mutually exclusive.
788   if (IsRecv(n)) {
789     string src_name;
790     s = GetNodeAttr(n->attrs(), "send_device", &src_name);
791     if (!s.ok()) return s;
792     DeviceNameUtils::ParsedName parsed_src_name;
793     if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) {
794       s = errors::Internal("Bad send_device attr '", src_name, "' in node ",
795                            n->name());
796       return s;
797     }
798     if (!DeviceNameUtils::IsSameAddressSpace(parsed_src_name, local_dev_name)) {
799       // Value is going to be the sink of an RPC.
800       attr->set_nic_compatible(true);
801       VLOG(2) << "node " << n->name() << " is the sink of an RPC in";
802     } else if ((local_dev_name.type == "CPU" || n->IsHostRecv()) &&
803                parsed_src_name.type != "CPU") {
804       // Value is going to be the sink of a local DMA from GPU to CPU (or
805       // other types of accelerators).
806       attr->set_gpu_compatible(true);
807       VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy";
808     } else {
809       VLOG(2) << "default alloc case local type " << local_dev_name.type
810               << " remote type " << parsed_src_name.type;
811     }
812   }
813   if (IsSend(dst)) {
814     string dst_name;
815     s = GetNodeAttr(dst->attrs(), "recv_device", &dst_name);
816     if (!s.ok()) return s;
817     DeviceNameUtils::ParsedName parsed_dst_name;
818     if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) {
819       s = errors::Internal("Bad recv_device attr '", dst_name, "' in node ",
820                            n->name());
821       return s;
822     }
823     if (!DeviceNameUtils::IsSameAddressSpace(parsed_dst_name, local_dev_name)) {
824       // Value is going to be the source of an RPC.
825       attr->set_nic_compatible(true);
826       VLOG(2) << "node " << n->name() << " is the source of an RPC out";
827     } else if ((local_dev_name.type == "CPU" || dst->IsHostSend()) &&
828                parsed_dst_name.type != "CPU") {
829       // Value is going to be the source of a local DMA from CPU to GPU (or
830       // other types of accelerators).
831       // Note that this does not cover the case where the allocation of the
832       // output tensor is not generated by the src: n.
833       attr->set_gpu_compatible(true);
834       VLOG(2) << "node " << n->name() << " is the source of a cpu->gpu copy";
835     } else {
836       VLOG(2) << "default alloc case local type " << local_dev_name.type
837               << " remote type " << parsed_dst_name.type;
838     }
839   }
840   if (n->IsCollective()) {
841     // We'll make the sweeping assumption that any collective op is going
842     // to be involved in network i/o.
843     attr->set_nic_compatible(true);
844   }
845   return s;
846 }
847 
848 // The state associated with one invocation of ExecutorImpl::Run.
849 // ExecutorState dispatches nodes when they become ready and keeps
850 // track of how many predecessors of a node have not done (pending_).
851 class ExecutorState {
852  public:
853   ExecutorState(const Executor::Args& args, ExecutorImpl* impl);
854   ~ExecutorState();
855 
856   void RunAsync(Executor::DoneCallback done);
857 
858  private:
859   // Either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
860   // TODO(yuanbyu): A better way to do "has_value"?
861   struct Entry {
Entrytensorflow::__anon6f8fc96b0111::ExecutorState::Entry862     Entry() {}
Entrytensorflow::__anon6f8fc96b0111::ExecutorState::Entry863     Entry(const Entry& other)
864         : ref(other.ref),
865           ref_mu(other.ref_mu),
866           has_value(other.has_value),
867           val_field_is_set(other.val_field_is_set),
868           alloc_attr(other.alloc_attr),
869           device_context(other.device_context) {
870       if (val_field_is_set) {
871         val.Init(*other.val);
872       }
873     }
~Entrytensorflow::__anon6f8fc96b0111::ExecutorState::Entry874     ~Entry() {
875       if (val_field_is_set) val.Destroy();
876     }
877 
operator =tensorflow::__anon6f8fc96b0111::ExecutorState::Entry878     Entry& operator=(const Entry& other) {
879       if (val_field_is_set) {
880         val.Destroy();
881       }
882       ref = other.ref;
883       ref_mu = other.ref_mu;
884       has_value = other.has_value;
885       val_field_is_set = other.val_field_is_set;
886       alloc_attr = other.alloc_attr;
887       device_context = other.device_context;
888       if (val_field_is_set) {
889         val.Init(*other.val);
890       }
891       return *this;
892     }
893 
operator =tensorflow::__anon6f8fc96b0111::ExecutorState::Entry894     Entry& operator=(Entry&& other) {
895       if (val_field_is_set) {
896         val.Destroy();
897       }
898       ref = other.ref;
899       ref_mu = other.ref_mu;
900       has_value = other.has_value;
901       val_field_is_set = other.val_field_is_set;
902       alloc_attr = other.alloc_attr;
903       device_context = other.device_context;
904       if (val_field_is_set) {
905         val.Init(std::move(*other.val));
906       }
907       return *this;
908     }
909 
910     // Clears the <val> field.
ClearValtensorflow::__anon6f8fc96b0111::ExecutorState::Entry911     void ClearVal() {
912       if (val_field_is_set) {
913         val.Destroy();
914         val_field_is_set = false;
915         has_value = false;
916       }
917     }
918 
919     // A tensor value, if val_field_is_set.
920     ManualConstructor<Tensor> val;
921 
922     Tensor* ref = nullptr;    // A tensor reference.
923     mutex* ref_mu = nullptr;  // mutex for *ref if ref is not nullptr.
924 
925     // Whether the value exists, either in <val> or <ref>.
926     bool has_value = false;
927 
928     bool val_field_is_set = false;
929 
930     // The attributes of the allocator that creates the tensor.
931     AllocatorAttributes alloc_attr;
932 
933     // Every entry carries an optional DeviceContext containing
934     // Device-specific information about how the Tensor was produced.
935     DeviceContext* device_context = nullptr;
936   };
937 
938   // Contains a value for [node->id()] for the device context assigned by the
939   // device at the beginning of a step.
940   DeviceContextMap device_context_map_;
941 
942   struct TaggedNode;
943   typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
944   typedef gtl::InlinedVector<Entry, 4> EntryVector;
945 
946   struct IterationState {
IterationStatetensorflow::__anon6f8fc96b0111::ExecutorState::IterationState947     explicit IterationState(const PendingCounts* pending_counts,
948                             int total_input_tensors)
949         : input_tensors(new Entry[total_input_tensors]),
950           outstanding_ops(0),
951           outstanding_frame_count(0),
952           counts_(*pending_counts) {  // Initialize with copy of *pending_counts
953     }
954 
955     // The state of an iteration.
956 
957     // One copy per iteration. For iteration k, i-th node's j-th input is in
958     // input_tensors[k][impl_->nodes[i].input_start + j]. An entry is either
959     // a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
960     //
961     // NOTE: No need to protect input_tensors[i] by any locks because it
962     // is resized once. Each element of tensors_ is written once by the
963     // source node of an edge and is cleared by the destination of the same
964     // edge. The latter node is never run concurrently with the former node.
965     Entry* input_tensors;
966 
967     // The number of outstanding ops for each iteration.
968     size_t outstanding_ops;
969 
970     // The number of outstanding frames for each iteration.
971     int outstanding_frame_count;
pendingtensorflow::__anon6f8fc96b0111::ExecutorState::IterationState972     int pending(PendingCounts::Handle h) { return counts_.pending(h); }
decrement_pendingtensorflow::__anon6f8fc96b0111::ExecutorState::IterationState973     int decrement_pending(PendingCounts::Handle h, int v) {
974       return counts_.decrement_pending(h, v);
975     }
976     // Mark a merge node as live
977     // REQUIRES: Node corresponding to "h" is a merge node
mark_livetensorflow::__anon6f8fc96b0111::ExecutorState::IterationState978     void mark_live(PendingCounts::Handle h) { counts_.mark_live(h); }
979     // Mark a node to show that processing has started.
mark_startedtensorflow::__anon6f8fc96b0111::ExecutorState::IterationState980     void mark_started(PendingCounts::Handle h) { counts_.mark_started(h); }
981     // Mark a node to show that processing has completed.
mark_completedtensorflow::__anon6f8fc96b0111::ExecutorState::IterationState982     void mark_completed(PendingCounts::Handle h) { counts_.mark_completed(h); }
node_statetensorflow::__anon6f8fc96b0111::ExecutorState::IterationState983     PendingCounts::NodeState node_state(PendingCounts::Handle h) {
984       return counts_.node_state(h);
985     }
986 
dead_counttensorflow::__anon6f8fc96b0111::ExecutorState::IterationState987     int dead_count(PendingCounts::Handle h) { return counts_.dead_count(h); }
increment_dead_counttensorflow::__anon6f8fc96b0111::ExecutorState::IterationState988     void increment_dead_count(PendingCounts::Handle h) {
989       counts_.increment_dead_count(h);
990     }
adjust_for_activationtensorflow::__anon6f8fc96b0111::ExecutorState::IterationState991     void adjust_for_activation(PendingCounts::Handle h, bool increment_dead,
992                                int* pending_result, int* dead_result) {
993       counts_.adjust_for_activation(h, increment_dead, pending_result,
994                                     dead_result);
995     }
996 
~IterationStatetensorflow::__anon6f8fc96b0111::ExecutorState::IterationState997     ~IterationState() { delete[] input_tensors; }
998 
999    private:
1000     PendingCounts counts_;
1001   };
1002 
1003   struct FrameState {
FrameStatetensorflow::__anon6f8fc96b0111::ExecutorState::FrameState1004     explicit FrameState(const ExecutorImpl* impl, int parallel_iters)
1005         : executor(impl),
1006           max_parallel_iterations(parallel_iters),
1007           num_outstanding_iterations(1) {}
1008 
1009     // A new frame is created for each loop. Execution starts at iteration 0.
1010     // When a value at iteration 0 passes through a NextIteration node,
1011     // iteration 1 is created and starts running. Note that iteration 0 may
1012     // still be running so multiple iterations may run in parallel. The
1013     // frame maintains the state of iterations in several data structures
1014     // such as pending_count and input_tensors. When iteration 0 completes,
1015     // we garbage collect the state of iteration 0.
1016     //
1017     // A frame instance is considered "done" and can be garbage collected
1018     // if all its inputs have entered and all its iterations are "done".
1019     //
1020     // A frame manages the live iterations of an iterative computation.
1021     // Iteration i is considered "done" when there are no outstanding ops,
1022     // frames at iteration i are done, all recvs for this iteration are
1023     // completed, and iteration i-1 is done. For iteration 0, we instead
1024     // wait for there to be no more pending inputs of the frame.
1025     //
1026     // Frames and iterations are garbage collected once they are done.
1027     // The state we need to keep around is highly dependent on the
1028     // parallelism enabled by the scheduler. We may want to have the
1029     // scheduler dynamically control the outstanding number of live
1030     // parallel frames and iterations. To reduce the state space, the
1031     // scheduler might want to schedule ops in inner frames first and
1032     // lower iterations first.
1033     //
1034     // This frame state is mostly initialized lazily on demand so we
1035     // don't introduce unnecessary overhead.
1036 
1037     // The executor the frame is in.
1038     const ExecutorImpl* executor = nullptr;
1039 
1040     // The name of this frame, which is the concatenation of its parent
1041     // frame name, the iteration of the parent frame when this frame was
1042     // created, and the value of the attr 'frame_name'.
1043     string frame_name;
1044 
1045     // The unique id for this frame. Generated by fingerprinting
1046     // frame_name.
1047     uint64 frame_id;
1048 
1049     // The iteration id of its parent frame when this frame is created.
1050     // -1 if there is no parent frame. The frame_name/parent_iter pair
1051     // uniquely identifies this FrameState.
1052     int64 parent_iter = -1;
1053 
1054     // The FrameState of its parent frame.
1055     FrameState* parent_frame = nullptr;
1056 
1057     // The maximum allowed number of parallel iterations.
1058     const int max_parallel_iterations;
1059 
1060     // The number of inputs this frame is still waiting.
1061     int num_pending_inputs = 0;
1062 
1063     // The highest iteration number we have reached so far in this frame.
1064     int64 iteration_count GUARDED_BY(mu) = 0;
1065 
1066     // The number of outstanding iterations.
1067     int num_outstanding_iterations GUARDED_BY(mu) = 1;
1068 
1069     // The active iteration states of this frame.
1070     gtl::InlinedVector<IterationState*, 12> iterations;
1071 
1072     // The NextIteration nodes to enter a new iteration. If the number of
1073     // outstanding iterations reaches the limit, we will defer the start of
1074     // the next iteration until the number of outstanding iterations falls
1075     // below the limit.
1076     std::vector<std::pair<const Node*, Entry>> next_iter_roots GUARDED_BY(mu);
1077 
1078     // The values of the loop invariants for this loop. They are added into
1079     // this list as they "enter" the frame. When a loop invariant enters,
1080     // we make it available to all active iterations. When the frame starts
1081     // a new iteration, we make all the current loop invariants available
1082     // to the new iteration.
1083     std::vector<std::pair<const Node*, Entry>> inv_values GUARDED_BY(mu);
1084 
1085     // The list of dead exit nodes for the current highest iteration. We
1086     // will only "execute" the dead exits of the final iteration.
1087     std::vector<const Node*> dead_exits GUARDED_BY(mu);
1088 
1089     // Static information specific to this frame.
1090     PendingCounts* pending_counts = nullptr;
1091     int total_input_tensors = 0;
1092     std::vector<const Node*>* nodes = nullptr;
1093 
1094     // Lock ordering: ExecutorState.mu_ < mu;
1095     // during structured traversal: parent_frame->mu < mu.
1096     mutex mu;
1097 
InitializeFrameInfotensorflow::__anon6f8fc96b0111::ExecutorState::FrameState1098     void InitializeFrameInfo(const string& enter_name) {
1099       auto it_frame_info = executor->frame_info_.find(enter_name);
1100       DCHECK(it_frame_info != executor->frame_info_.end());
1101       ExecutorImpl::FrameInfo* finfo = it_frame_info->second;
1102       pending_counts = finfo->pending_counts;
1103       total_input_tensors = finfo->total_inputs;
1104       num_pending_inputs = finfo->input_count;
1105       nodes = finfo->nodes;
1106     }
1107 
GetIterationtensorflow::__anon6f8fc96b0111::ExecutorState::FrameState1108     inline IterationState* GetIteration(int64 iter)
1109         EXCLUSIVE_LOCKS_REQUIRED(mu) {
1110       size_t index = iter % iterations.size();
1111       return iterations[index];
1112     }
1113 
SetIterationtensorflow::__anon6f8fc96b0111::ExecutorState::FrameState1114     inline void SetIteration(int64 iter, IterationState* state)
1115         EXCLUSIVE_LOCKS_REQUIRED(mu) {
1116       size_t index = iter % iterations.size();
1117       DCHECK(state == nullptr || iterations[index] == nullptr);
1118       iterations[index] = state;
1119     }
1120 
1121     // Decrement the outstanding op count and clean up the iterations in the
1122     // frame. Return true iff the execution of the frame is done.
DecrementOutstandingOpstensorflow::__anon6f8fc96b0111::ExecutorState::FrameState1123     inline bool DecrementOutstandingOps(const GraphView* gview, int64 iter,
1124                                         TaggedNodeSeq* ready) {
1125       mutex_lock l(mu);
1126       return DecrementOutstandingOpsLocked(gview, iter, ready);
1127     }
1128 
1129     // Decrement the outstanding op count and clean up the iterations in the
1130     // frame. Return true iff the execution of the frame is done.
DecrementOutstandingOpsLockedtensorflow::__anon6f8fc96b0111::ExecutorState::FrameState1131     inline bool DecrementOutstandingOpsLocked(const GraphView* gview,
1132                                               int64 iter, TaggedNodeSeq* ready)
1133         EXCLUSIVE_LOCKS_REQUIRED(mu) {
1134       IterationState* istate = GetIteration(iter);
1135       istate->outstanding_ops--;
1136       if (istate->outstanding_ops != 0) {
1137         return false;
1138       } else {
1139         return CleanupIterations(gview, iter, ready);
1140       }
1141     }
1142 
1143     // Returns true if the computation in the frame is completed.
IsFrameDonetensorflow::__anon6f8fc96b0111::ExecutorState::FrameState1144     inline bool IsFrameDone() EXCLUSIVE_LOCKS_REQUIRED(mu) {
1145       return (num_pending_inputs == 0 && num_outstanding_iterations == 0);
1146     }
1147 
1148     // Returns true if the iteration of the frame is completed.
1149     bool IsIterationDone(int64 iter) EXCLUSIVE_LOCKS_REQUIRED(mu);
1150 
1151     // Increments the iteration id. If this is a new iteration, initialize it.
1152     void IncrementIteration(const GraphView* gview, TaggedNodeSeq* ready)
1153         EXCLUSIVE_LOCKS_REQUIRED(mu);
1154 
1155     // Activate all the deferred NextIteration nodes in a new iteration.
1156     void ActivateNexts(const GraphView* gview, int64 iter, TaggedNodeSeq* ready)
1157         EXCLUSIVE_LOCKS_REQUIRED(mu);
1158 
1159     // Activate all the current loop invariants in a new iteration.
1160     void ActivateLoopInvs(const GraphView* gview, int64 iter,
1161                           TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu);
1162 
1163     // Add a new loop invariant and make it available to all active
1164     // iterations.
1165     void AddLoopInv(const NodeItem* item, const Entry& value,
1166                     TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu);
1167 
1168     // Activate the successors of a node. Contents of *outputs are left in an
1169     // indeterminate state after returning from this method.
1170     void ActivateNodes(const NodeItem* item, const bool is_dead, int64 iter,
1171                        EntryVector* outputs, TaggedNodeSeq* ready)
1172         EXCLUSIVE_LOCKS_REQUIRED(mu);
1173 
1174     // Cleanup iterations of this frame starting from iteration iter.
1175     bool CleanupIterations(const GraphView* gview, int64 iter,
1176                            TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu);
1177 
~FrameStatetensorflow::__anon6f8fc96b0111::ExecutorState::FrameState1178     ~FrameState() {
1179       for (size_t i = 0; i < iterations.size(); ++i) {
1180         delete iterations[i];
1181         iterations[i] = nullptr;
1182       }
1183     }
1184   };
1185 
1186   // A tagged node: <frame*, iter, node*>.
1187   struct TaggedNode {
1188     const Node* node = nullptr;
1189     FrameState* input_frame = nullptr;
1190     int64 input_iter = -1;
1191     bool is_dead = false;
1192 
TaggedNodetensorflow::__anon6f8fc96b0111::ExecutorState::TaggedNode1193     TaggedNode(const Node* t_node, FrameState* in_frame, int64 in_iter,
1194                bool dead) {
1195       node = t_node;
1196       input_frame = in_frame;
1197       input_iter = in_iter;
1198       is_dead = dead;
1199     }
1200   };
1201 
1202   // A drop-in replacement for std::deque<TaggedNode>.  We typically don't
1203   // have that many nodes in the ready queue, so we just use a vector and
1204   // don't free up memory from the queue as we consume nodes.
1205   class TaggedNodeReadyQueue {
1206    public:
TaggedNodeReadyQueue()1207     TaggedNodeReadyQueue() : front_index_(0) {}
1208 
push_back(TaggedNode node)1209     void push_back(TaggedNode node) { ready_.push_back(node); }
front() const1210     TaggedNode front() const {
1211       DCHECK_LT(front_index_, ready_.size());
1212       return ready_[front_index_];
1213     }
pop_front()1214     void pop_front() {
1215       DCHECK_LT(front_index_, ready_.size());
1216       front_index_++;
1217       if ((front_index_ == ready_.size()) || (front_index_ > 16384)) {
1218         if (front_index_ == ready_.size()) {
1219           ready_.clear();
1220         } else {
1221           // Lots of unused entries at beginning of vector: move everything
1222           // down to start of vector.
1223           ready_.erase(ready_.begin(), ready_.begin() + front_index_);
1224         }
1225         front_index_ = 0;
1226       }
1227     }
empty() const1228     bool empty() const { return ready_.empty(); }
begin() const1229     const TaggedNode* begin() const { return ready_.begin() + front_index_; }
end() const1230     const TaggedNode* end() const { return ready_.end(); }
1231 
1232    private:
1233     gtl::InlinedVector<TaggedNode, 16> ready_;
1234     int front_index_;
1235   };
1236 
1237   struct AsyncState;
1238 
1239   const bool vlog_;  // true if VLOG_IS_ON(1). Used to check vlog cheaply.
1240 
1241   // true if LogMemory::IsEnabled(). Used to check memory enabled cheaply.
1242   const bool log_memory_;
1243 
1244   int64 step_id_;
1245   // Not owned.
1246   Rendezvous* rendezvous_;
1247   CollectiveExecutor* collective_executor_ = nullptr;
1248   SessionState* session_state_;
1249   string session_handle_;
1250   TensorStore* tensor_store_;
1251   // Step-local container.
1252   ScopedStepContainer* step_container_;
1253   StepStatsCollectorInterface* const stats_collector_;
1254   const tracing::EventCollector* const event_collector_;
1255   Context context_;
1256 
1257   // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
1258   // instead of a pointer?  (avoids having to delete).
1259   checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
1260   CallFrameInterface* call_frame_;
1261   const ExecutorImpl* impl_;
1262   CancellationManager* cancellation_manager_;
1263   Executor::Args::Runner runner_;
1264   bool sync_on_finish_;
1265   const bool trace_using_annotations_;
1266 
1267   // Owned.
1268 
1269   // A flag that is set on error after the frame state has been
1270   // dumped for diagnostic purposes.
1271   bool dumped_on_error_ = false;
1272 
1273   // The root frame in which the execution of this step is started.
1274   FrameState* root_frame_;
1275 
1276   // Invoked when the execution finishes.
1277   Executor::DoneCallback done_cb_;
1278 
1279   std::atomic_int_fast32_t num_outstanding_ops_;
1280 
1281   // Available via OpKernelContext to every OpKernel invocation.
1282   mutex num_deferred_ops_mu_;
1283   condition_variable no_deferred_ops_cv_;
1284   int64 num_deferred_ops_ GUARDED_BY(num_deferred_ops_mu_) = 0;
1285 
1286   mutex mu_;
1287   Status status_ GUARDED_BY(mu_);
1288 
1289   // Mapping from frame name to outstanding frames. A new frame is created
1290   // at some iteration of an active frame. So the unique key for the new
1291   // child frame is composed of the name of the parent frame, the iteration
1292   // number at which the parent frame is creating the new frame, and the
1293   // name of the new frame from nodedef.
1294   gtl::FlatMap<string, FrameState*> outstanding_frames_ GUARDED_BY(mu_);
1295 
1296   // The unique name of a frame.
MakeFrameName(FrameState * frame,int64 iter_id,const string & name)1297   inline string MakeFrameName(FrameState* frame, int64 iter_id,
1298                               const string& name) {
1299     return strings::StrCat(frame->frame_name, ";", iter_id, ";", name);
1300   }
1301 
1302   // Find an existing or create a new child frame in the frame 'frame' at
1303   // iteration 'iter'.
1304   void FindOrCreateChildFrame(FrameState* frame, int64 iter, const Node* node,
1305                               FrameState** child);
1306 
1307   // Delete a frame. Called when the frame is done.
1308   void DeleteFrame(FrameState* frame, TaggedNodeSeq* ready);
1309 
1310   // Cleanup frames and iterations starting from frame/iter. Called when
1311   // a child frame is done.
1312   void CleanupFramesIterations(FrameState* frame, int64 iter,
1313                                TaggedNodeSeq* ready);
1314 
1315   // Process a ready node in current thread.
1316   void Process(TaggedNode node, int64 scheduled_nsec);
1317 
1318   // Before invoking item->kernel, fills in its "inputs".
1319   Status PrepareInputs(const NodeItem& item, Entry* first_input,
1320                        TensorValueVec* inputs,
1321                        DeviceContextVec* input_device_contexts,
1322                        AllocatorAttributeVec* input_alloc_attrs,
1323                        bool* is_input_dead);
1324 
1325   // After item->kernel computation is done, processes its outputs.
1326   Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
1327                         EntryVector* outputs, NodeExecStatsInterface* stats);
1328 
1329   // After processing the outputs, propagates the outputs to their dsts.
1330   // Contents of *outputs are left in an indeterminate state after
1331   // returning from this method.
1332   void PropagateOutputs(const TaggedNode& tagged_node, const NodeItem* item,
1333                         EntryVector* outputs, TaggedNodeSeq* ready);
1334 
1335   // "node" just finishes. Takes ownership of "stats". Returns true if
1336   // execution has completed.
1337   bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready,
1338                 NodeExecStatsInterface* stats,
1339                 TaggedNodeReadyQueue* inline_ready);
1340 
1341   // Schedule all the expensive nodes in 'ready', and put all the inexpensive
1342   // nodes in 'ready' into 'inline_ready'.
1343   void ScheduleReady(const TaggedNodeSeq& ready,
1344                      TaggedNodeReadyQueue* inline_ready);
1345 
1346   // For debugging/logging only.
1347   inline void MaybeMarkCompleted(FrameState* frame, int64 iter, int64 id);
1348 
1349   // Provide debugging output about an outstanding node in the executor.
1350   void DumpPendingNodeState(const int node_id, const Entry* input_vector,
1351                             bool show_nodes_with_no_ready_inputs);
1352   void DumpActiveNodeState(const int node_id, const Entry* input_vector);
1353 
1354   // Provide debugging output about an outstanding iteration in the executor.
1355   void DumpIterationState(const FrameState* frame, IterationState* iteration);
1356 
1357   // Provide debugging output of the state of the executor.
1358   void DumpState();
1359   const Tensor* GetTensorValueForDump(const Entry& input);
1360 
1361   // Clean up when this executor is done.
1362   void Finish();
1363   // Schedule Finish() on a separate thread if it needs to wait for deferred
1364   // async ops to complete; otherwise run it on the current thread.
1365   void ScheduleFinish();
1366 
1367   // A standalone routine for this expression so that we can express
1368   // that we don't want thread safety analysis on this reference (it's
1369   // safe to do without the lock because the iterations array never
1370   // resizes and this particular iteration's array element will not
1371   // be changed out from under us because the iteration is still alive).
GetInputTensors(FrameState * input_frame,int64 input_iter) const1372   Entry* GetInputTensors(FrameState* input_frame,
1373                          int64 input_iter) const NO_THREAD_SAFETY_ANALYSIS {
1374     return input_frame->GetIteration(input_iter)->input_tensors;
1375   }
1376 };
1377 
ExecutorState(const Executor::Args & args,ExecutorImpl * impl)1378 ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
1379     : vlog_(VLOG_IS_ON(1)),
1380       log_memory_(LogMemory::IsEnabled()),
1381       step_id_(args.step_id),
1382       rendezvous_(args.rendezvous),
1383       collective_executor_(args.collective_executor),
1384       session_state_(args.session_state),
1385       session_handle_(args.session_handle),
1386       tensor_store_(args.tensor_store),
1387       step_container_(args.step_container),
1388       stats_collector_(args.stats_collector),
1389       event_collector_(
1390           tracing::GetEventCollector(tracing::EventCategory::kCompute)),
1391       context_(ContextKind::kThread),
1392       slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
1393       call_frame_(args.call_frame),
1394       impl_(impl),
1395       cancellation_manager_(args.cancellation_manager),
1396       runner_(args.runner),
1397       sync_on_finish_(args.sync_on_finish),
1398       trace_using_annotations_(impl->params_.device->TraceUsingAnnotations()),
1399       num_outstanding_ops_(0) {
1400   // We start the entire execution in iteration 0 of the root frame
1401   // so let us create the root frame and the state for iteration 0.
1402   // We assume root_frame_->frame_name.empty().
1403   root_frame_ = new FrameState(impl_, 1);
1404   root_frame_->frame_id = 0;  // must be 0
1405   root_frame_->InitializeFrameInfo(root_frame_->frame_name);
1406 
1407   // Initialize iteration 0.
1408   root_frame_->iterations.resize(root_frame_->max_parallel_iterations);
1409   root_frame_->iterations[0] = new IterationState(
1410       root_frame_->pending_counts, root_frame_->total_input_tensors);
1411 
1412   outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
1413 }
1414 
~ExecutorState()1415 ExecutorState::~ExecutorState() {
1416   for (auto name_frame : outstanding_frames_) {
1417     delete name_frame.second;
1418   }
1419   for (auto it : device_context_map_) {
1420     it->Unref();
1421   }
1422   delete slice_reader_cache_;
1423 }
1424 
BuildControlFlowInfo(const Graph * g,ControlFlowInfo * cf_info)1425 Status ExecutorImpl::BuildControlFlowInfo(const Graph* g,
1426                                           ControlFlowInfo* cf_info) {
1427   const int num_nodes = g->num_node_ids();
1428   cf_info->frame_names.resize(num_nodes);
1429   std::vector<Node*> parent_nodes;
1430   parent_nodes.resize(num_nodes);
1431   std::vector<bool> visited;
1432   visited.resize(num_nodes);
1433 
1434   string frame_name;
1435   std::deque<Node*> ready;
1436 
1437   // Initialize with the root nodes.
1438   for (Node* n : g->nodes()) {
1439     if (n->in_edges().empty()) {
1440       visited[n->id()] = true;
1441       cf_info->unique_frame_names.insert(frame_name);
1442       ready.push_back(n);
1443     }
1444   }
1445 
1446   while (!ready.empty()) {
1447     Node* curr_node = ready.front();
1448     int curr_id = curr_node->id();
1449     ready.pop_front();
1450 
1451     Node* parent = nullptr;
1452     if (IsEnter(curr_node)) {
1453       // Enter a child frame.
1454       TF_RETURN_IF_ERROR(
1455           GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name));
1456       parent = curr_node;
1457     } else if (IsExit(curr_node)) {
1458       // Exit to the parent frame.
1459       parent = parent_nodes[curr_id];
1460       frame_name = cf_info->frame_names[parent->id()];
1461       parent = parent_nodes[parent->id()];
1462     } else {
1463       parent = parent_nodes[curr_id];
1464       frame_name = cf_info->frame_names[curr_id];
1465     }
1466 
1467     for (const Edge* out_edge : curr_node->out_edges()) {
1468       Node* out = out_edge->dst();
1469       const int out_id = out->id();
1470 
1471       // Add to ready queue if not visited.
1472       bool is_visited = visited[out_id];
1473       if (!is_visited) {
1474         ready.push_back(out);
1475         visited[out_id] = true;
1476 
1477         // Process the node 'out'.
1478         cf_info->frame_names[out_id] = frame_name;
1479         parent_nodes[out_id] = parent;
1480         cf_info->unique_frame_names.insert(frame_name);
1481       }
1482     }
1483   }
1484 
1485   return Status::OK();
1486 }
1487 
InitializePending(const Graph * graph,const ControlFlowInfo & cf_info)1488 void ExecutorImpl::InitializePending(const Graph* graph,
1489                                      const ControlFlowInfo& cf_info) {
1490   for (auto& it : cf_info.unique_frame_names) {
1491     FrameInfo* finfo = EnsureFrameInfo(it);
1492     PendingCounts* counts = new PendingCounts(finfo->pending_counts_layout);
1493     DCHECK_EQ(finfo->pending_counts, nullptr);
1494     finfo->pending_counts = counts;
1495   }
1496   for (const Node* n : graph->nodes()) {
1497     const int id = n->id();
1498     const string& name = cf_info.frame_names[id];
1499     size_t max_pending, max_dead;
1500     GetMaxPendingCounts(n, &max_pending, &max_dead);
1501     const NodeItem* item = gview_.node(id);
1502     PendingCounts* counts = EnsureFrameInfo(name)->pending_counts;
1503     counts->set_initial_count(item->pending_id, max_pending);
1504   }
1505 }
1506 
RunAsync(Executor::DoneCallback done)1507 void ExecutorState::RunAsync(Executor::DoneCallback done) {
1508   const Graph* graph = impl_->graph_.get();
1509   TaggedNodeSeq ready;
1510 
1511   // Ask the device to fill in the device context map.
1512   Device* device = impl_->params_.device;
1513   const Status fill_status =
1514       device->FillContextMap(graph, &device_context_map_);
1515   if (!fill_status.ok()) {
1516     delete this;
1517     done(fill_status);
1518     return;
1519   }
1520 
1521   // Initialize the ready queue.
1522   for (const Node* n : impl_->root_nodes_) {
1523     DCHECK_EQ(n->in_edges().size(), 0);
1524     ready.push_back(TaggedNode{n, root_frame_, 0, false});
1525   }
1526   if (ready.empty()) {
1527     delete this;
1528     done(Status::OK());
1529   } else {
1530     num_outstanding_ops_ = ready.size();
1531     root_frame_->iterations[0]->outstanding_ops = ready.size();
1532     done_cb_ = std::move(done);
1533     // Schedule to run all the ready ops in thread pool.
1534     ScheduleReady(ready, nullptr);
1535   }
1536 }
1537 
// State kept alive for executing an asynchronous node in another
// thread.  NOTE: We need to make a copy of p.inputs,
// p.input_device_contexts, and p.input_alloc_attrs for asynchronous
// kernels because OpKernelContext methods like input_type(i) need
// the params to point at valid input vectors. It's not an issue for
// sync kernels because these vectors are kept on the stack.
//
// Process() heap-allocates one of these per async launch; the completion
// callback it passes to Device::ComputeAsync() deletes it.
struct ExecutorState::AsyncState {
  // Copies `p` and the three input vectors it points to (which live on
  // Process()'s stack), then re-points `params` at the copies held here.
  AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
             const NodeItem* _item, Entry* _first_input,
             NodeExecStatsInterface* _stats)
      : saved_inputs(*p.inputs),
        saved_input_device_contexts(*p.input_device_contexts),
        saved_input_alloc_attrs(*p.input_alloc_attrs),
        params(p),
        tagged_node(_tagged_node),
        item(_item),
        first_input(_first_input),
        // ParamsButClearingEigenGPUDevice does equivalent of
        //   params.eigen_gpu_device = nullptr;
        // `ctx` is declared after `params` and `item`, so both are already
        // initialized here.
        ctx(ParamsButClearingEigenGPUDevice(&params), item->num_outputs),
        stats(_stats) {
    // Re-point at our own copies. NOTE(review): this relies on
    // OpKernelContext keeping the `&params` pointer rather than copying the
    // params during construction — confirm against OpKernelContext.
    params.inputs = &saved_inputs;
    params.input_device_contexts = &saved_input_device_contexts;
    params.input_alloc_attrs = &saved_input_alloc_attrs;
  }

  // NOTE: member declaration order determines initialization order; the
  // saved vectors and `params` must precede `ctx`.
  TensorValueVec saved_inputs;
  DeviceContextVec saved_input_device_contexts;
  AllocatorAttributeVec saved_input_alloc_attrs;
  OpKernelContext::Params params;
  TaggedNode tagged_node;
  const NodeItem* item;
  // First of the node's input entries; cleared by the completion callback.
  Entry* first_input;
  OpKernelContext ctx;
  // Per-node stats (may be null when stats are not being collected).
  NodeExecStatsInterface* stats;

 private:
  // Clears `p->eigen_gpu_device` and returns `p`, so it can be used inside
  // the member initializer list above.
  OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
      OpKernelContext::Params* p) {
    // Ensure OpKernelContext constructor will make a new eigen GPU device if
    // necessary.
    p->eigen_gpu_device = nullptr;  // Force allocation
    return p;
  }
};
1583 
1584 // Returns true if `item` might be traced by the given trace and event
1585 // collectors. Returns false only if `item` definitely will not be traced.
MightTrace(const NodeItem & item,const tracing::EventCollector * event_collector,bool using_annotations)1586 bool MightTrace(const NodeItem& item,
1587                 const tracing::EventCollector* event_collector,
1588                 bool using_annotations) {
1589   // Tracing will only be enabled if either `event_collector` is non null,
1590   // or `trace_collector` is non-null and enabled for this particular kernel.
1591   // Although `tracing::ScopedActivity`,
1592   // `tracing::ScopedAnnotation`, and `tracing::ScopedRegion` check subsets of
1593   // these properties internally in their constructors, the cost of passing the
1594   // necessary arguments to them can be significant, so we avoid constructing
1595   // them in the common case (when we know they will not be used).
1596   if (event_collector != nullptr) {
1597     return true;
1598   }
1599   auto* trace_collector = tracing::GetTraceCollector();
1600   if (trace_collector) {
1601     if (using_annotations) {
1602       return trace_collector->IsEnabledForAnnotations();
1603     } else {
1604       return trace_collector->IsEnabledForActivities(
1605           item.kernel->IsExpensive());
1606     }
1607   }
1608   return false;
1609 }
1610 
Process(TaggedNode tagged_node,int64 scheduled_nsec)1611 void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
1612   WithContext wc(context_);
1613   const GraphView& gview = impl_->gview_;
1614   TaggedNodeSeq ready;
1615   TaggedNodeReadyQueue inline_ready;
1616 
1617   // Parameters passed to OpKernel::Compute.
1618   TensorValueVec inputs;
1619   DeviceContextVec input_device_contexts;
1620   AllocatorAttributeVec input_alloc_attrs;
1621 
1622   OpKernelContext::Params params;
1623   params.step_id = step_id_;
1624   Device* device = impl_->params_.device;
1625   params.device = device;
1626   params.log_memory = log_memory_;
1627   params.record_tensor_accesses = impl_->device_record_tensor_accesses_;
1628   params.rendezvous = rendezvous_;
1629   params.collective_executor = collective_executor_;
1630   params.session_state = session_state_;
1631   params.session_handle = session_handle_;
1632   params.tensor_store = tensor_store_;
1633   params.cancellation_manager = cancellation_manager_;
1634   params.call_frame = call_frame_;
1635   params.function_library = impl_->params_.function_library;
1636   params.resource_manager = device->resource_manager();
1637   params.step_container = step_container_;
1638   params.slice_reader_cache = slice_reader_cache_;
1639   params.inputs = &inputs;
1640   params.input_device_contexts = &input_device_contexts;
1641   params.input_alloc_attrs = &input_alloc_attrs;
1642   params.runner = &runner_;
1643   params.stats_collector = stats_collector_;
1644   params.inc_num_deferred_ops_function = [this]() {
1645     mutex_lock lock(num_deferred_ops_mu_);
1646     num_deferred_ops_++;
1647   };
1648   params.dec_num_deferred_ops_function = [this]() {
1649     mutex_lock lock(num_deferred_ops_mu_);
1650     num_deferred_ops_--;
1651     if (num_deferred_ops_ == 0) {
1652       no_deferred_ops_cv_.notify_all();
1653     }
1654   };
1655 
1656   Status s;
1657   NodeExecStatsInterface* stats = nullptr;
1658 
1659   EntryVector outputs;
1660   bool completed = false;
1661   inline_ready.push_back(tagged_node);
1662   while (!inline_ready.empty()) {
1663     tagged_node = inline_ready.front();
1664     inline_ready.pop_front();
1665     const Node* node = tagged_node.node;
1666     FrameState* input_frame = tagged_node.input_frame;
1667     const int64 input_iter = tagged_node.input_iter;
1668     const int id = node->id();
1669     const NodeItem& item = *gview.node(id);
1670 
1671     // TODO(misard) Replace with a finer-grain enabling flag once we
1672     // add better optional debugging support.
1673     if (vlog_ && VLOG_IS_ON(1)) {
1674       mutex_lock l(input_frame->mu);
1675       input_frame->GetIteration(input_iter)->mark_started(item.pending_id);
1676     }
1677 
1678     // Set the device_context for this node id, if it exists.
1679     if (id < device_context_map_.size()) {
1680       params.op_device_context = device_context_map_[id];
1681     }
1682 
1683     params.track_allocations = false;
1684     stats = nullptr;
1685     if (stats_collector_ && !tagged_node.is_dead) {
1686       stats = stats_collector_->CreateNodeExecStats(node);
1687       // Track allocations if and only if we are collecting statistics, and
1688       // `stats` object is expecting allocations to be tracked.
1689       params.track_allocations = stats ? stats->TrackAllocations() : false;
1690       nodestats::SetScheduled(stats, scheduled_nsec);
1691       nodestats::SetAllStart(stats);
1692     }
1693 
1694     if (vlog_) {
1695       VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
1696               << SummarizeNode(*node) << (tagged_node.is_dead ? " is dead" : "")
1697               << " device: " << device->name();
1698     }
1699 
1700     Entry* input_tensors = GetInputTensors(input_frame, input_iter);
1701     Entry* first_input = input_tensors + item.input_start;
1702     outputs.clear();
1703 
1704     TensorReferenceVector accessed_tensors;
1705     DeviceContext* device_context = nullptr;
1706     // Only execute this node if it is not dead or it is a send/recv
1707     // transfer node. For transfer nodes, we need to propagate the "dead"
1708     // bit even when the node is dead.
1709     bool launched_asynchronously = false;
1710     if (tagged_node.is_dead && !IsTransferNode(node)) {
1711       outputs.resize(item.num_outputs);
1712     } else {
1713       // Prepares inputs.
1714       bool is_input_dead = false;
1715       s = PrepareInputs(item, first_input, &inputs, &input_device_contexts,
1716                         &input_alloc_attrs, &is_input_dead);
1717       if (!s.ok()) {
1718         // Clear inputs.
1719         int num_inputs = item.num_inputs;
1720         for (int i = 0; i < num_inputs; ++i) {
1721           (first_input + i)->ClearVal();
1722         }
1723         MaybeMarkCompleted(input_frame, input_iter, id);
1724         // Continue to process the nodes in 'inline_ready'.
1725         completed = NodeDone(s, item.node, ready, stats, &inline_ready);
1726         continue;
1727       }
1728 
1729       // Set up compute params.
1730       OpKernel* op_kernel = item.kernel;
1731       params.op_kernel = op_kernel;
1732       params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
1733       params.is_input_dead = is_input_dead;
1734       params.output_attr_array = item.output_attrs();
1735       params.forward_from_array = item.forward_from();
1736 
1737       if (item.kernel_is_async) {
1738         // Asynchronous computes.
1739         AsyncOpKernel* async = item.kernel->AsAsync();
1740         DCHECK(async != nullptr);
1741         launched_asynchronously = true;
1742         AsyncState* state =
1743             new AsyncState(params, tagged_node, &item, first_input, stats);
1744 
1745         auto done = [this, state]() {
1746           Device* device = impl_->params_.device;
1747           NodeExecStatsInterface* stats = state->stats;  // Shorthand
1748           Entry* first_input = state->first_input;       // Shorthand
1749 
1750           nodestats::SetOpEnd(stats);
1751           EntryVector outputs;
1752           Status s = ProcessOutputs(*state->item, &state->ctx, &outputs, stats);
1753           nodestats::SetMemory(stats, &state->ctx);
1754           if (vlog_) {
1755             VLOG(2) << "Async kernel done: " << state->item->node->id()
1756                     << " step " << step_id_ << " "
1757                     << SummarizeNode(*state->item->node)
1758                     << (state->tagged_node.is_dead ? " is dead" : "")
1759                     << " device: " << device->name();
1760           }
1761 
1762           // Clears inputs.
1763           const int num_inputs = state->item->num_inputs;
1764           for (int i = 0; i < num_inputs; ++i) {
1765             (first_input + i)->ClearVal();
1766           }
1767           FrameState* input_frame = state->tagged_node.input_frame;
1768           const int64 input_iter = state->tagged_node.input_iter;
1769           const int id = state->tagged_node.node->id();
1770           MaybeMarkCompleted(input_frame, input_iter, id);
1771           TaggedNodeSeq ready;
1772           if (s.ok()) {
1773             PropagateOutputs(state->tagged_node, state->item, &outputs, &ready);
1774           }
1775           outputs.clear();
1776           if (s.ok() && impl_->device_record_tensor_accesses_) {
1777             // Get the list of all tensors accessed during the execution
1778             TensorReferenceVector accessed;
1779             state->ctx.retrieve_accessed_tensors(&accessed);
1780             nodestats::SetReferencedTensors(stats, accessed);
1781             // callee takes ownership of the vector
1782             device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
1783                                                  accessed);
1784           }
1785           const bool completed =
1786               NodeDone(s, state->item->node, ready, stats, nullptr);
1787           delete state;
1788           if (completed) ScheduleFinish();
1789         };
1790         nodestats::SetOpStart(stats);
1791         device->ComputeAsync(async, &state->ctx, done);
1792       } else {
1793         // Synchronous computes.
1794         OpKernelContext ctx(&params, item.num_outputs);
1795         nodestats::SetOpStart(stats);
1796 
1797         if (TF_PREDICT_FALSE(
1798                 MightTrace(item, event_collector_, trace_using_annotations_))) {
1799           const string& op_name = op_kernel->name();
1800           tracing::ScopedRegion region(tracing::EventCategory::kCompute,
1801                                        op_name);
1802           if (trace_using_annotations_) {
1803             // The OpKernel may create child activities (such as GPU kernel
1804             // launches), so use a `ScopedAnnotation` to relate these activities
1805             // in the trace.
1806             tracing::ScopedAnnotation activity(
1807                 op_name, strings::StrCat(op_kernel->type_string(),
1808                                          "#id=", step_id_, "#"));
1809             device->Compute(op_kernel, &ctx);
1810           } else {
1811             // Use the cheaper `ScopedActivity` to trace just the OpKernel
1812             // execution.
1813             tracing::ScopedActivity activity(
1814                 op_name,
1815                 strings::StrCat(op_kernel->type_string(), "#id=", step_id_,
1816                                 "#"),
1817                 item.kernel->IsExpensive());
1818             device->Compute(op_kernel, &ctx);
1819           }
1820         } else {
1821           // In the common case, avoid creating any tracing objects.
1822           if (op_kernel->IsExpensive()) {
1823             KernelTimer timer;
1824             device->Compute(op_kernel, &ctx);
1825             op_kernel->UpdateCostEstimate(timer.ElapsedCycles());
1826           } else {
1827             device->Compute(op_kernel, &ctx);
1828           }
1829         }
1830 
1831         nodestats::SetOpEnd(stats);
1832         s = ProcessOutputs(item, &ctx, &outputs, stats);
1833         if (s.ok() && impl_->device_record_tensor_accesses_) {
1834           // Get the list of all tensors accessed during the execution
1835           ctx.retrieve_accessed_tensors(&accessed_tensors);
1836           device_context = ctx.op_device_context();
1837         }
1838         nodestats::SetMemory(stats, &ctx);
1839       }
1840     }
1841 
1842     if (!launched_asynchronously) {
1843       if (vlog_) {
1844         VLOG(2) << "Synchronous kernel done: " << id << " step "
1845                 << params.step_id << " " << SummarizeNode(*node)
1846                 << (tagged_node.is_dead ? " is dead: " : "")
1847                 << " device: " << device->name();
1848       }
1849 
1850       // Clears inputs.
1851       const int num_inputs = item.num_inputs;
1852       for (int i = 0; i < num_inputs; ++i) {
1853         (first_input + i)->ClearVal();
1854       }
1855       MaybeMarkCompleted(input_frame, input_iter, id);
1856       // Propagates outputs.
1857       if (s.ok()) {
1858         PropagateOutputs(tagged_node, &item, &outputs, &ready);
1859       }
1860       outputs.clear();
1861       if (!accessed_tensors.empty()) {
1862         nodestats::SetReferencedTensors(stats, accessed_tensors);
1863         // device_context is set above in synchronous computes
1864         device->ConsumeListOfAccessedTensors(device_context, accessed_tensors);
1865       }
1866       if (stats) {
1867         scheduled_nsec = nodestats::NowInNsec();
1868       }
1869       // Postprocess.
1870       completed = NodeDone(s, item.node, ready, stats, &inline_ready);
1871     }
1872   }  // while !inline_ready.empty()
1873 
1874   // This thread of computation is done if completed = true.
1875   if (completed) ScheduleFinish();
1876 }
1877 
PrepareInputs(const NodeItem & item,Entry * first_input,TensorValueVec * inputs,DeviceContextVec * input_device_contexts,AllocatorAttributeVec * input_alloc_attrs,bool * is_input_dead)1878 Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input,
1879                                     TensorValueVec* inputs,
1880                                     DeviceContextVec* input_device_contexts,
1881                                     AllocatorAttributeVec* input_alloc_attrs,
1882                                     bool* is_input_dead) {
1883   const Node* node = item.node;
1884 
1885   inputs->clear();
1886   inputs->resize(item.num_inputs);
1887   input_device_contexts->clear();
1888   input_device_contexts->resize(item.num_inputs);
1889   input_alloc_attrs->clear();
1890   input_alloc_attrs->resize(item.num_inputs);
1891 
1892   *is_input_dead = false;
1893 
1894   bool is_merge = item.is_merge;
1895   for (int i = 0; i < item.num_inputs; ++i) {
1896     const bool expect_ref = IsRefType(item.input_type(i));
1897     Entry* entry = first_input + i;
1898     (*input_device_contexts)[i] = entry->device_context;
1899     (*input_alloc_attrs)[i] = entry->alloc_attr;
1900 
1901     // i-th input.
1902     TensorValue* inp = &(*inputs)[i];
1903 
1904     // Only merge and transfer nodes can have no-value inputs.
1905     if (!entry->has_value) {
1906       if (!is_merge) {
1907         DCHECK(IsTransferNode(node)) << node->name() << " - input " << i;
1908         DCHECK(!entry->val_field_is_set) << node->name() << " - input " << i;
1909         entry->has_value = true;
1910         entry->val_field_is_set = true;
1911         entry->val.Init(*kEmptyTensor);
1912         inp->tensor = entry->val.get();
1913         *is_input_dead = true;
1914       }
1915       continue;
1916     }
1917     if (entry->ref == nullptr) {
1918       if (expect_ref) {
1919         return AttachDef(
1920             errors::InvalidArgument(i, "-th input expects a ref type"),
1921             item.kernel->def());
1922       }
1923       inp->tensor = entry->val.get();
1924     } else {
1925       {
1926         tf_shared_lock ml(*entry->ref_mu);
1927         if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) {
1928           return AttachDef(errors::FailedPrecondition(
1929                                "Attempting to use uninitialized value ",
1930                                item.kernel->requested_input(i)),
1931                            item.kernel->def());
1932         }
1933       }
1934       if (expect_ref) {
1935         inp->mutex_if_ref = entry->ref_mu;
1936         inp->tensor = entry->ref;
1937       } else {
1938         // Automatically deref the tensor ref when the op expects a
1939         // tensor but is given a ref to a tensor.  Need to deref it
1940         // under the mutex.
1941         {
1942           tf_shared_lock l(*(entry->ref_mu));
1943           DCHECK(!entry->val_field_is_set);
1944           entry->val.Init(*entry->ref);
1945           entry->val_field_is_set = true;
1946         }
1947         entry->ref = nullptr;
1948         entry->ref_mu = nullptr;
1949 
1950         inp->tensor = entry->val.get();
1951         // The dtype of entry->ref could have been changed by another operation
1952         // that ran after the operation that "produced" it executed, so
1953         // re-validate that the type of the dereferenced tensor matches the
1954         // expected input type.
1955         if (item.input_type(i) != inp->tensor->dtype()) {
1956           return AttachDef(
1957               errors::InvalidArgument(
1958                   i, "-th input expects type ",
1959                   DataTypeString(item.input_type(i)),
1960                   " but automatically dereferenced input tensor has type ",
1961                   DataTypeString(inp->tensor->dtype())),
1962               item.kernel->def());
1963         }
1964       }
1965     }
1966   }
1967   return Status::OK();
1968 }
1969 
ProcessOutputs(const NodeItem & item,OpKernelContext * ctx,EntryVector * outputs,NodeExecStatsInterface * stats)1970 Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
1971                                      EntryVector* outputs,
1972                                      NodeExecStatsInterface* stats) {
1973   const Node* node = item.node;
1974   DCHECK_EQ(0, outputs->size());
1975   outputs->resize(item.num_outputs);
1976 
1977   Status s = ctx->status();
1978   if (!s.ok()) {
1979     s = AttachDef(s, item.kernel->def());
1980     // TODO(misard) Replace with a finer-grain enabling flag once we
1981     // add better optional debugging support.
1982     if (vlog_ && VLOG_IS_ON(1)) {
1983       LOG(WARNING) << this << " Compute status: " << s;
1984       DumpState();
1985     }
1986     if (s.code() == error::RESOURCE_EXHAUSTED) {
1987       if (stats_collector_) {
1988         string err = stats_collector_->ReportAllocsOnResourceExhausted(
1989             s.error_message());
1990         s = Status(s.code(), strings::StrCat(s.error_message(), err));
1991       } else {
1992         s = Status(
1993             s.code(),
1994             strings::StrCat(
1995                 s.error_message(),
1996                 "\nHint: If you want to see a list of allocated tensors when "
1997                 "OOM happens, add report_tensor_allocations_upon_oom "
1998                 "to RunOptions for current allocation info.\n"));
1999       }
2000     }
2001     return s;
2002   }
2003 
2004   // Get the device_context for this node id, if it exists.
2005   DeviceContext* device_context = nullptr;
2006   if (node->id() < device_context_map_.size()) {
2007     device_context = device_context_map_[node->id()];
2008   }
2009 
2010   for (int i = 0; i < item.num_outputs; ++i) {
2011     const TensorValue val = ctx->release_output(i);
2012     if (val.tensor == nullptr) {
2013       // Unless it's a Switch or a Recv, the node must produce a
2014       // tensor value at i-th output.
2015       if (!IsSwitch(node) && !IsRecv(node)) {
2016         s.Update(errors::Internal("Missing ", i, "-th output from ",
2017                                   FormatNodeForError(*node)));
2018       }
2019     } else {
2020       Entry* out = &((*outputs)[i]);
2021 
2022       // Set the device context of the output entry.
2023       out->device_context = device_context;
2024 
2025       // Set the allocator attributes of the output entry.
2026       out->alloc_attr = ctx->output_alloc_attr(i);
2027 
2028       // Sanity check of output tensor types.
2029       DataType dtype;
2030       if (val.is_ref()) {
2031         tf_shared_lock ml(*val.mutex_if_ref);
2032         dtype = MakeRefType(val->dtype());
2033       } else {
2034         dtype = val->dtype();
2035       }
2036       if (dtype == item.output_type(i)) {
2037         if (stats && val.tensor->IsInitialized()) {
2038           nodestats::SetOutput(stats, i, val.tensor);
2039         }
2040         if (val.is_ref()) {
2041           out->has_value = true;
2042           out->ref = val.tensor;
2043           out->ref_mu = val.mutex_if_ref;
2044           if (log_memory_) {
2045             Tensor to_log;
2046             {
2047               // Dereference the tensor under the lock.
2048               tf_shared_lock l(*out->ref_mu);
2049               to_log = *out->ref;
2050             }
2051             LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
2052                                           ctx->step_id(), i, to_log);
2053           }
2054         } else {
2055           // NOTE that std::move is used here, so val.tensor goes to
2056           // uninitialized state (val.tensor->IsInitialized return false).
2057           DCHECK(!out->val_field_is_set);
2058           out->has_value = true;
2059           out->val_field_is_set = true;
2060           out->val.Init(std::move(*val.tensor));
2061           if (log_memory_) {
2062             LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
2063                                           ctx->step_id(), i, *out->val);
2064           }
2065         }
2066       } else {
2067         s.Update(errors::Internal("Output ", i, " of type ",
2068                                   DataTypeString(dtype),
2069                                   " does not match declared output type ",
2070                                   DataTypeString(item.output_type(i)),
2071                                   " for node ", FormatNodeForError(*node)));
2072       }
2073     }
2074     if (!val.is_ref()) {
2075       // If OpKernelContext returns outputs via pass-by-value, we
2076       // don't need this trouble.
2077       delete val.tensor;
2078     }
2079   }
2080   return s;
2081 }
2082 
PropagateOutputs(const TaggedNode & tagged_node,const NodeItem * item,EntryVector * outputs,TaggedNodeSeq * ready)2083 void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
2084                                      const NodeItem* item, EntryVector* outputs,
2085                                      TaggedNodeSeq* ready) {
2086   auto activity_handle =
2087       [&]() -> std::unique_ptr<tracing::TraceCollector::Handle> {
2088     auto* trace_collector = tracing::GetTraceCollector();
2089     if (TF_PREDICT_FALSE(trace_collector != nullptr &&
2090                          trace_collector->IsEnabledForActivities(
2091                              false /* is_expensive */))) {
2092       const string& op_name = item->kernel->name();
2093       // Intentionally using ExecutorPropagateOutputs as the first key so that
2094       // users are aware that it's not the op invocation.
2095       return trace_collector->CreateActivityHandle(
2096           "ExecutorPropagateOutputs",
2097           strings::StrCat(op_name, "#id=", step_id_, "#"),
2098           false /* is_expensive */);
2099     } else {
2100       return nullptr;
2101     }
2102   }();
2103 
2104   const Node* node = tagged_node.node;
2105   FrameState* input_frame = tagged_node.input_frame;
2106   const int64 input_iter = tagged_node.input_iter;
2107   const bool is_dead = tagged_node.is_dead;
2108 
2109   // Propagates outputs along out edges, and puts newly ready nodes
2110   // into the ready queue.
2111   ready->clear();
2112   bool is_frame_done = false;
2113   FrameState* output_frame = input_frame;
2114   int64 output_iter = input_iter;
2115 
2116   if (!item->is_enter_exit_or_next_iter) {
2117     // Fast path for nodes types that don't need special handling
2118     DCHECK_EQ(input_frame, output_frame);
2119     // Normal path for most nodes
2120     mutex_lock l(input_frame->mu);
2121     output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
2122     is_frame_done = input_frame->DecrementOutstandingOpsLocked(
2123         &impl_->gview_, input_iter, ready);
2124   } else if (item->is_enter) {
2125     FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame);
2126     output_iter = 0;
2127     {
2128       const NodeItem* item = impl_->gview_.node(node->id());
2129       mutex_lock l(output_frame->mu);
2130       if (item->is_constant_enter) {
2131         // Propagate to all active iterations if this is a loop invariant.
2132         output_frame->AddLoopInv(item, (*outputs)[0], ready);
2133       } else {
2134         output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
2135       }
2136       output_frame->num_pending_inputs--;
2137     }
2138     is_frame_done =
2139         input_frame->DecrementOutstandingOps(&impl_->gview_, input_iter, ready);
2140   } else if (item->is_exit) {
2141     if (is_dead) {
2142       mutex_lock l(input_frame->mu);
2143       // Stop and remember this node if it is a dead exit.
2144       if (input_iter == input_frame->iteration_count) {
2145         input_frame->dead_exits.push_back(node);
2146       }
2147       is_frame_done = input_frame->DecrementOutstandingOpsLocked(
2148           &impl_->gview_, input_iter, ready);
2149     } else {
2150       output_frame = input_frame->parent_frame;
2151       output_iter = input_frame->parent_iter;
2152       {
2153         mutex_lock l(output_frame->mu);
2154         output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
2155       }
2156       is_frame_done = input_frame->DecrementOutstandingOps(&impl_->gview_,
2157                                                            input_iter, ready);
2158     }
2159   } else {
2160     DCHECK(IsNextIteration(node));
2161     mutex_lock l(input_frame->mu);
2162     if (is_dead) {
2163       // Stop the deadness propagation.
2164       output_frame = nullptr;
2165     } else {
2166       if (input_iter == input_frame->iteration_count &&
2167           input_frame->num_outstanding_iterations ==
2168               input_frame->max_parallel_iterations) {
2169         // Reached the maximum for parallel iterations.
2170         input_frame->next_iter_roots.push_back({node, (*outputs)[0]});
2171         output_frame = nullptr;
2172       } else {
2173         // If this is a new iteration, start it.
2174         if (input_iter == input_frame->iteration_count) {
2175           input_frame->IncrementIteration(&impl_->gview_, ready);
2176         }
2177         output_iter = input_iter + 1;
2178       }
2179     }
2180     if (output_frame != nullptr) {
2181       // This is the case when node is not Enter, Exit, or NextIteration.
2182       DCHECK(input_frame == output_frame);
2183       output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
2184     }
2185     is_frame_done = input_frame->DecrementOutstandingOpsLocked(
2186         &impl_->gview_, input_iter, ready);
2187   }
2188 
2189   // At this point, this node is completely done. We also know if the
2190   // completion of this node makes its frame completed.
2191   if (is_frame_done) {
2192     FrameState* parent_frame = input_frame->parent_frame;
2193     const int64 parent_iter = input_frame->parent_iter;
2194     DeleteFrame(input_frame, ready);
2195     if (parent_frame != nullptr) {
2196       // The completion of frame may cause completions in its parent frame.
2197       // So clean things up recursively.
2198       CleanupFramesIterations(parent_frame, parent_iter, ready);
2199     }
2200   }
2201 }
2202 
NodeDone(const Status & s,const Node * node,const TaggedNodeSeq & ready,NodeExecStatsInterface * stats,TaggedNodeReadyQueue * inline_ready)2203 bool ExecutorState::NodeDone(const Status& s, const Node* node,
2204                              const TaggedNodeSeq& ready,
2205                              NodeExecStatsInterface* stats,
2206                              TaggedNodeReadyQueue* inline_ready) {
2207   nodestats::SetAllEnd(stats);
2208   if (stats) {
2209     if (stats_collector_) {
2210       stats->Done(impl_->params_.device->name());
2211     } else {
2212       delete stats;
2213     }
2214   }
2215 
2216   bool abort_run = false;
2217   if (!s.ok()) {
2218     // Some error happened. This thread of computation is done.
2219     mutex_lock l(mu_);
2220     if (status_.ok()) {
2221       abort_run = true;
2222       status_ = s;
2223     }
2224   }
2225   if (abort_run) {
2226     TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
2227     if (rendezvous_) {
2228       rendezvous_->StartAbort(s);
2229     }
2230     if (collective_executor_) {
2231       collective_executor_->StartAbort(s);
2232     }
2233     if (cancellation_manager_) {
2234       cancellation_manager_->StartCancel();
2235     }
2236   }
2237 
2238   bool completed = false;
2239   const size_t ready_size = ready.size();
2240   if (ready_size == 0 || !s.ok()) {
2241     completed = (num_outstanding_ops_.fetch_sub(1) == 1);
2242   } else if (ready_size > 1) {
2243     num_outstanding_ops_.fetch_add(ready_size - 1, std::memory_order_relaxed);
2244   }
2245 
2246   // Schedule the ready nodes in 'ready'.
2247   if (s.ok()) {
2248     ScheduleReady(ready, inline_ready);
2249   }
2250   return completed;
2251 }
2252 
ScheduleReady(const TaggedNodeSeq & ready,TaggedNodeReadyQueue * inline_ready)2253 void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
2254                                   TaggedNodeReadyQueue* inline_ready) {
2255   if (ready.empty()) return;
2256 
2257   int64 scheduled_nsec = 0;
2258   if (stats_collector_) {
2259     scheduled_nsec = nodestats::NowInNsec();
2260   }
2261 
2262   if (inline_ready == nullptr) {
2263     // Schedule to run all the ready ops in thread pool.
2264     for (auto& tagged_node : ready) {
2265       runner_([=]() { Process(tagged_node, scheduled_nsec); });
2266     }
2267     return;
2268   }
2269 
2270   const GraphView& gview = impl_->gview_;
2271   const TaggedNode* curr_expensive_node = nullptr;
2272   for (auto& tagged_node : ready) {
2273     const NodeItem& item = *gview.node(tagged_node.node->id());
2274     if (tagged_node.is_dead || !item.kernel->IsExpensive()) {
2275       // Inline this inexpensive node.
2276       inline_ready->push_back(tagged_node);
2277     } else {
2278       if (curr_expensive_node) {
2279         // Dispatch to another thread since there is plenty of work to
2280         // do for this thread.
2281         runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
2282                           scheduled_nsec));
2283       }
2284       curr_expensive_node = &tagged_node;
2285     }
2286   }
2287   if (curr_expensive_node) {
2288     if (inline_ready->empty()) {
2289       // Tail recursion optimization
2290       inline_ready->push_back(*curr_expensive_node);
2291     } else {
2292       // There are inline nodes to run already. We dispatch this expensive
2293       // node to other thread.
2294       runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
2295                         scheduled_nsec));
2296     }
2297   }
2298 }
2299 
MaybeMarkCompleted(FrameState * frame,int64 iter,int64 node_id)2300 inline void ExecutorState::MaybeMarkCompleted(FrameState* frame, int64 iter,
2301                                               int64 node_id) {
2302   // TODO(misard) Replace with a finer-grain enabling flag once we
2303   // add better optional debugging support.
2304   if (vlog_ && VLOG_IS_ON(1)) {
2305     const NodeItem* item = impl_->gview_.node(node_id);
2306     mutex_lock l(frame->mu);
2307     frame->GetIteration(iter)->mark_completed(item->pending_id);
2308   }
2309 }
2310 
GetTensorValueForDump(const Entry & input)2311 const Tensor* ExecutorState::GetTensorValueForDump(const Entry& input) {
2312   if (!input.has_value) {
2313     return kEmptyTensor;
2314   } else if (input.ref == nullptr) {
2315     return input.val.get();
2316   } else {
2317     return input.ref;
2318   }
2319 }
2320 
DumpPendingNodeState(const int node_id,const Entry * input_vector,const bool show_nodes_with_no_ready_inputs)2321 void ExecutorState::DumpPendingNodeState(
2322     const int node_id, const Entry* input_vector,
2323     const bool show_nodes_with_no_ready_inputs) {
2324   const NodeItem& node_item = *impl_->gview_.node(node_id);
2325   const Node& node = *node_item.node;
2326   const int input_base = node_item.input_start;
2327   if (!show_nodes_with_no_ready_inputs) {
2328     bool has_ready_input = false;
2329     for (int i = 0; i < node.num_inputs(); ++i) {
2330       const Entry& input = input_vector[input_base + i];
2331       const Tensor* tensor = GetTensorValueForDump(input);
2332       if (tensor->IsInitialized()) {
2333         has_ready_input = true;
2334         break;
2335       }
2336     }
2337     if (!has_ready_input) {
2338       return;
2339     }
2340   }
2341   LOG(WARNING) << "    Pending Node: " << node.DebugString();
2342   for (int i = 0; i < node.num_inputs(); ++i) {
2343     const Entry& input = input_vector[input_base + i];
2344     const Tensor* tensor = GetTensorValueForDump(input);
2345     if (tensor->IsInitialized()) {
2346       LOG(WARNING) << "      Input " << i << ": "
2347                    << strings::StrCat(
2348                           "Tensor<type: ", DataTypeString(tensor->dtype()),
2349                           " shape: ", tensor->shape().DebugString(), ">");
2350     } else {
2351       LOG(WARNING) << "      Input " << i << ": not present";
2352     }
2353   }
2354 }
2355 
DumpActiveNodeState(const int node_id,const Entry * input_vector)2356 void ExecutorState::DumpActiveNodeState(const int node_id,
2357                                         const Entry* input_vector) {
2358   const NodeItem& node_item = *impl_->gview_.node(node_id);
2359   const Node& node = *node_item.node;
2360   LOG(WARNING) << "    Active Node: " << node.DebugString();
2361   const int input_base = node_item.input_start;
2362   for (int i = 0; i < node.num_inputs(); ++i) {
2363     const Entry& input = input_vector[input_base + i];
2364     const Tensor* tensor = GetTensorValueForDump(input);
2365     if (tensor->IsInitialized()) {
2366       LOG(WARNING) << "      Input " << i << ": "
2367                    << strings::StrCat(
2368                           "Tensor<type: ", DataTypeString(tensor->dtype()),
2369                           " shape: ", tensor->shape().DebugString(), ">");
2370     } else {
2371       LOG(WARNING) << "      Input " << i << ": not present";
2372     }
2373   }
2374 }
2375 
DumpIterationState(const FrameState * frame,IterationState * iteration)2376 void ExecutorState::DumpIterationState(const FrameState* frame,
2377                                        IterationState* iteration) {
2378   const std::vector<const Node*>* nodes = frame->nodes;
2379   // Dump any waiting nodes that are holding on to tensors.
2380   for (const Node* node : *nodes) {
2381     const int node_id = node->id();
2382     PendingCounts::Handle pending_id = impl_->gview_.node(node_id)->pending_id;
2383     if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
2384         iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
2385       DumpPendingNodeState(node_id, iteration->input_tensors, false);
2386     }
2387   }
2388   // Then the active nodes.
2389   for (const Node* node : *nodes) {
2390     const int node_id = node->id();
2391     PendingCounts::Handle pending_id = impl_->gview_.node(node_id)->pending_id;
2392     if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
2393       DumpActiveNodeState(node_id, iteration->input_tensors);
2394     }
2395   }
2396   // Show all input tensors in use.
2397   const int total_input_tensors = frame->total_input_tensors;
2398   size_t total_bytes = 0;
2399   for (int i = 0; i < total_input_tensors; ++i) {
2400     const Entry& input = iteration->input_tensors[i];
2401     const Tensor* tensor = GetTensorValueForDump(input);
2402     if (tensor->IsInitialized()) {
2403       LOG(WARNING) << "    Input " << i << ": "
2404                    << strings::StrCat(
2405                           "Tensor<type: ", DataTypeString(tensor->dtype()),
2406                           " shape: ", tensor->shape().DebugString(),
2407                           ", bytes: ", tensor->TotalBytes(), ">");
2408       total_bytes += tensor->TotalBytes();
2409     }
2410   }
2411   LOG(WARNING) << "    Total bytes " << total_bytes;
2412 }
2413 
DumpState()2414 void ExecutorState::DumpState() {
2415   mutex_lock l(mu_);
2416   if (!dumped_on_error_) {
2417     LOG(WARNING) << "Dumping state";
2418     for (auto& frame : outstanding_frames_) {
2419       LOG(WARNING) << frame.first;
2420       FrameState* frame_state = frame.second;
2421       mutex_lock frame_lock(frame_state->mu);
2422       for (IterationState* iteration : frame_state->iterations) {
2423         LOG(WARNING) << "  Iteration:";
2424         DumpIterationState(frame_state, iteration);
2425       }
2426     }
2427     dumped_on_error_ = true;
2428   }
2429 }
2430 
ScheduleFinish()2431 void ExecutorState::ScheduleFinish() {
2432   int num_deferred_ops;
2433   {
2434     mutex_lock lock(num_deferred_ops_mu_);
2435     num_deferred_ops = num_deferred_ops_;
2436   }
2437   if (num_deferred_ops > 0) {
2438     // Finish() may be blocked waiting for deferred async ops to complete. The
2439     // execution of deferred async ops may be waiting for non-enqueued ops of
2440     // other executors to complete. So running Finish() on the current thread
2441     // (inter-op threadpool thread) may lead to a deadlock due to threadpool
2442     // exhaustion. Instead, we run it on a separate thread to unblock the
2443     // threadpool thread.
2444     Env::Default()->SchedClosure([this]() { Finish(); });
2445   } else {
2446     Finish();
2447   }
2448 }
2449 
Finish()2450 void ExecutorState::Finish() {
2451   mu_.lock();
2452   auto status = status_;
2453   auto done_cb = std::move(done_cb_);
2454   auto runner = std::move(runner_);
2455   mu_.unlock();
2456   CHECK(done_cb != nullptr);
2457   Device* device = impl_->params_.device;
2458 
2459   // There are several potential race conditions below. To name a few:
2460   // 1. Even if the device's status is OK at the precise moment when
2461   // num_deferred_ops_ reaches 0, it could go bad before device->RefreshStatus()
2462   // is called below, caused by work enqueued onto the same device by other
2463   // concurrent ExecutorState objects.
2464   // 2. Some implementations of Device::RefreshStatus, such as
2465   // XlaDevice::RefreshStatus, may be inherently racy because it releases the
2466   // device mutex after a stream pointer is acquired and before the stream is
2467   // queried for status.
2468   // 3. It's the same for some implementations of Device::Sync, such as
2469   // XlaDevice::Sync.
2470   //
2471   // However, these race conditions are acceptable because a stream (and
2472   // therefore an XlaDevice) can only go from OK to not-OK, never the opposite,
2473   // which means we will at worst report errors when there isn't any, never the
2474   // opposite.
2475 
2476   // If inc_num_deferred_ops_function has ever been called, ExecutorState must
2477   // wait for all corresponding dec_num_deferred_ops_function calls to happen
2478   // regardless of status. This ensures that dec_num_deferred_ops_function can
2479   // safely use ExecutorState's resources.
2480   {
2481     mutex_lock lock(num_deferred_ops_mu_);
2482     while (num_deferred_ops_ > 0) {
2483       no_deferred_ops_cv_.wait(lock);
2484     }
2485   }
2486 
2487   // An early exit for devices don't allow sync on completion. Ops that run on
2488   // these devices should have used num_deferred_ops correctly to ensure the
2489   // device has finished all relevant work at this point.
2490   if (!device->AllowsSyncOnCompletion()) {
2491     status.Update(device->RefreshStatus());
2492     if (!status.ok()) {
2493       // In device async execution mode, it's possible for device execution to
2494       // lag behind ExecutorState scheduling so much that this is the first
2495       // place a device execution error surfaces.
2496       // If so, all ExecutorState::NodeDone calls have already happened with OK
2497       // status. This is the last defense where StartCancel must be called to
2498       // abort all computation still running on any device.
2499       // TODO(b/124523000): Always call Finish in a separate thread, so even if
2500       // StartCancel blocks the current thread's execution, we won't encounter
2501       // deadlocks caused by inter-op thread exhaustion.
2502       if (cancellation_manager_) {
2503         cancellation_manager_->StartCancel();
2504       }
2505     }
2506     delete this;
2507     runner([=]() { done_cb(status); });
2508     return;
2509   }
2510 
2511   if (sync_on_finish_ && status.ok()) {
2512     // Block until the device has finished all queued operations. For
2513     // devices like GPUs that continue to execute Ops after their Compute
2514     // methods have completed, this ensures that control is not returned to
2515     // the user until the step (and its side-effects) has actually completed.
2516     device->Sync([=](Status new_status) mutable {
2517       status.Update(new_status);
2518       delete this;
2519       runner([=]() { done_cb(status); });
2520     });
2521   } else {
2522     delete this;
2523     runner([=]() { done_cb(status); });
2524   }
2525 }
2526 
FindOrCreateChildFrame(FrameState * frame,int64 iter,const Node * node,FrameState ** child)2527 void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
2528                                            const Node* node,
2529                                            FrameState** child) {
2530   // Get the child frame name.
2531   string enter_name;
2532   Status s = GetNodeAttr(node->attrs(), "frame_name", &enter_name);
2533   DCHECK(s.ok()) << s;
2534   const string child_name = MakeFrameName(frame, iter, enter_name);
2535 
2536   {
2537     mutex_lock executor_lock(mu_);
2538     auto it = outstanding_frames_.find(child_name);
2539     if (it != outstanding_frames_.end()) {
2540       *child = it->second;
2541       return;
2542     }
2543   }
2544 
2545   // Need to create a new frame instance.
2546   // Note that this new frame instance is created without any locks.
2547   if (vlog_) VLOG(2) << "Create frame: " << child_name;
2548 
2549   int parallel_iters;
2550   s = GetNodeAttr(node->attrs(), "parallel_iterations", &parallel_iters);
2551   DCHECK(s.ok()) << s;
2552   FrameState* temp = new FrameState(impl_, parallel_iters);
2553   temp->frame_name = child_name;
2554   temp->frame_id = Hash64(child_name);
2555   temp->parent_frame = frame;
2556   temp->parent_iter = iter;
2557   temp->InitializeFrameInfo(enter_name);
2558 
2559   // 'iterations' is a fixed-length circular buffer.
2560   temp->iterations.resize(temp->max_parallel_iterations + 1);
2561   // Initialize iteration 0.
2562   temp->iterations[0] =
2563       new IterationState(temp->pending_counts, temp->total_input_tensors);
2564 
2565   {
2566     mutex_lock executor_lock(mu_);
2567     auto it = outstanding_frames_.find(child_name);
2568     if (it != outstanding_frames_.end()) {
2569       *child = it->second;
2570     } else {
2571       mutex_lock frame_lock(frame->mu);
2572       frame->GetIteration(iter)->outstanding_frame_count++;
2573       outstanding_frames_[child_name] = temp;
2574       *child = temp;
2575       temp = nullptr;
2576     }
2577   }
2578   delete temp;  // Not used so delete it.
2579 }
2580 
// Removes 'frame' from outstanding_frames_ and deletes it.
//
// Before deleting, each Exit node recorded in frame->dead_exits is propagated
// as a dead output into iteration 'frame->parent_iter' of the parent frame.
// Every consumer's pending/dead counts are updated. Consumers that become
// ready are appended to 'ready' and counted in that parent iteration's
// outstanding_ops.
void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
  // First, propagate dead_exits (if any) to the parent frame.
  FrameState* parent_frame = frame->parent_frame;
  const int64 parent_iter = frame->parent_iter;
  if (parent_frame != nullptr) {
    // The parent frame's mu is acquired before this frame's mu.
    mutex_lock parent_frame_lock(parent_frame->mu);
    // Propagate all the dead exits to the parent frame.
    mutex_lock this_frame_lock(frame->mu);
    for (const Node* node : frame->dead_exits) {
      auto parent_iter_state = parent_frame->GetIteration(parent_iter);
      for (const Edge* e : node->out_edges()) {
        const Node* dst_node = e->dst();

        const auto dst_pending_id =
            impl_->gview_.node(dst_node->id())->pending_id;

        // TODO(yuanbyu): We don't need this if we require the subgraph
        // given to an executor not to contain a sink node.
        if (dst_node->IsSink()) continue;

        bool dst_dead = true;
        bool dst_ready = false;
        // We know this is a dead input to dst.
        if (IsMerge(dst_node)) {
          if (e->IsControlEdge()) {
            // A control input to a Merge decrements its pending count by 2.
            // The Merge is ready when the count reaches 0, or when it reaches 1
            // and every data input is dead.
            parent_iter_state->decrement_pending(dst_pending_id, 2);
            int count = parent_iter_state->pending(dst_pending_id);
            int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
            dst_dead = (dead_cnt == dst_node->num_inputs());
            dst_ready = (count == 0) || ((count == 1) && dst_dead);
          } else {
            // A dead data input to a Merge only bumps its dead count. It is
            // ready only once all inputs are dead and the pending count is 1.
            parent_iter_state->increment_dead_count(dst_pending_id);
            const int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
            dst_dead = (dead_cnt == dst_node->num_inputs());
            dst_ready =
                (parent_iter_state->pending(dst_pending_id) == 1) && dst_dead;
          }
        } else {
          // Non-Merge consumer: record the dead input and count it as
          // arrived. dst_dead stays true.
          parent_iter_state->increment_dead_count(dst_pending_id);
          dst_ready =
              (parent_iter_state->decrement_pending(dst_pending_id, 1) == 0);
        }
        if (dst_ready) {
          // A ControlTrigger node is always activated as live.
          if (IsControlTrigger(dst_node)) dst_dead = false;
          ready->emplace_back(dst_node, parent_frame, parent_iter, dst_dead);
          parent_iter_state->outstanding_ops++;
        }
      }
    }
  }

  // Delete the frame. 'frame_name' refers into 'frame', so it is only used
  // before the delete below.
  const string& frame_name = frame->frame_name;
  if (vlog_) VLOG(2) << "Delete frame " << frame_name;
  {
    mutex_lock executor_lock(mu_);
    outstanding_frames_.erase(frame_name);
  }
  delete frame;
}
2641 
CleanupFramesIterations(FrameState * frame,int64 iter,TaggedNodeSeq * ready)2642 void ExecutorState::CleanupFramesIterations(FrameState* frame, int64 iter,
2643                                             TaggedNodeSeq* ready) {
2644   bool is_frame_done = false;
2645   {
2646     mutex_lock frame_lock(frame->mu);
2647     frame->GetIteration(iter)->outstanding_frame_count--;
2648     is_frame_done = frame->CleanupIterations(&impl_->gview_, iter, ready);
2649   }
2650   if (is_frame_done) {
2651     FrameState* parent_frame = frame->parent_frame;
2652     const int64 parent_iter = frame->parent_iter;
2653     DeleteFrame(frame, ready);
2654     if (parent_frame != nullptr) {
2655       // The completion of frame may cause completions in its parent frame.
2656       // So clean things up recursively.
2657       CleanupFramesIterations(parent_frame, parent_iter, ready);
2658     }
2659   }
2660 }
2661 
// Propagates the outputs of the node described by 'item' along each of its
// output edges into iteration 'iter' of this frame. 'is_dead' is true when the
// node itself is dead.
//
// For each non-sink destination, this updates its pending/dead counts in the
// iteration. When the destination needs this input, the output is stored in the
// iteration's input_tensors: moved on the edge flagged is_last, copied
// otherwise. Every destination that becomes ready is appended to 'ready', and
// the iteration's outstanding_ops is incremented.
// NOTE(review): callers in this file hold this frame's mu while calling this;
// confirm that this is required for all callers.
void ExecutorState::FrameState::ActivateNodes(const NodeItem* item,
                                              const bool is_dead, int64 iter,
                                              EntryVector* outputs,
                                              TaggedNodeSeq* ready) {
  const GraphView& gview = executor->gview_;
  IterationState* iter_state = GetIteration(iter);
  const size_t num_output_edges = item->num_output_edges;
  const EdgeInfo* edges = item->output_edge_list();
  Entry* input_tensors = iter_state->input_tensors;
  for (size_t out_index = 0; out_index < num_output_edges; out_index++) {
    const EdgeInfo& e = edges[out_index];
    const int dst_id = e.dst_id;
    const NodeItem* dst_item = gview.node(dst_id);
    const PendingCounts::Handle dst_pending_id = dst_item->pending_id;
    const int src_slot = e.output_slot;

    // TODO(yuanbyu): We don't need this if we require the subgraph
    // given to an executor not to contain a sink node.
    if (dst_item->is_sink) continue;

    bool dst_dead = false;
    bool dst_ready = false;
    // True iff this input for dst is needed. We only set this input for
    // dst if this flag is true. This is needed to make the thread safety
    // analysis happy.
    const bool is_control_edge = (src_slot == Graph::kControlSlot);
    bool dst_need_input = !is_control_edge;
    if (dst_item->is_merge) {
      // A merge node is ready if all control inputs have arrived and either
      // a) a live data input becomes available or b) all data inputs are
      // dead. For Merge, pending's LSB is set iff a live data input has
      // arrived.
      if (is_control_edge) {
        // Each control input accounts for 2 in a Merge's pending count.
        iter_state->decrement_pending(dst_pending_id, 2);
        int count = iter_state->pending(dst_pending_id);
        int dead_cnt = iter_state->dead_count(dst_pending_id);
        dst_dead = (dead_cnt == dst_item->num_inputs);
        dst_ready = (count == 0) || ((count == 1) && dst_dead);
      } else {
        if ((*outputs)[src_slot].has_value) {
          // This is a live data input.
          int count = iter_state->pending(dst_pending_id);
          iter_state->mark_live(dst_pending_id);
          // Only the first live edge sets the input and (potentially)
          // triggers execution. The low bit of count is set if and
          // only if no live input has been used yet (mark_live clears
          // it). The node should be started if and only if this is
          // the first live input and there are no pending control
          // edges, i.e. count == 1.
          dst_ready = (count == 1);
          dst_need_input = ((count & 0x1) == 1);
        } else {
          // This is a dead data input. Note that dst_node is dead if node is
          // a dead enter. We need this to handle properly a while loop on
          // the untaken branch of a conditional.
          // TODO(yuanbyu): This is a bit hacky, but a good solution for
          // now.
          iter_state->increment_dead_count(dst_pending_id);
          const int dead_cnt = iter_state->dead_count(dst_pending_id);
          dst_dead = (dead_cnt == dst_item->num_inputs) || item->is_enter;
          dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead;
          dst_need_input = false;
        }
      }
    } else {
      // Non-merge destination: this input counts as dead if the source node
      // is dead or the data output carries no value. dst becomes ready when
      // its pending count reaches zero, and is dead if any input was dead.
      const bool increment_dead =
          (is_dead || (!is_control_edge && !(*outputs)[src_slot].has_value));
      int pending, dead;
      iter_state->adjust_for_activation(dst_pending_id, increment_dead,
                                        &pending, &dead);
      dst_dead = (dead > 0);
      dst_ready = (pending == 0);
    }

    if (dst_need_input) {
      const int dst_slot = e.input_slot;
      const int dst_loc = dst_item->input_start + dst_slot;
      // The edge marked is_last may steal the output; earlier edges copy it.
      if (e.is_last) {
        input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
      } else {
        input_tensors[dst_loc] = (*outputs)[src_slot];
      }
    }

    // Add dst to the ready queue if it's ready
    if (dst_ready) {
      // A ControlTrigger node is always activated as live.
      if (dst_item->is_control_trigger) dst_dead = false;
      ready->emplace_back(dst_item->node, this, iter, dst_dead);
      iter_state->outstanding_ops++;
    }
  }
}
2754 
ActivateNexts(const GraphView * gview,int64 iter,TaggedNodeSeq * ready)2755 void ExecutorState::FrameState::ActivateNexts(const GraphView* gview,
2756                                               int64 iter,
2757                                               TaggedNodeSeq* ready) {
2758   // Propagate the deferred NextIteration nodes to the new iteration.
2759   for (auto& node_entry : next_iter_roots) {
2760     const Node* node = node_entry.first;
2761     const Entry& entry = node_entry.second;
2762     const bool is_dead = !entry.has_value;
2763     const NodeItem* item = gview->node(node->id());
2764     EntryVector outputs{entry};
2765     ActivateNodes(item, is_dead, iter, &outputs, ready);
2766   }
2767   next_iter_roots.clear();
2768 }
2769 
ActivateLoopInvs(const GraphView * gview,int64 iter,TaggedNodeSeq * ready)2770 void ExecutorState::FrameState::ActivateLoopInvs(const GraphView* gview,
2771                                                  int64 iter,
2772                                                  TaggedNodeSeq* ready) {
2773   // Propagate loop invariants to the new iteration.
2774   for (auto& node_entry : inv_values) {
2775     const Node* node = node_entry.first;
2776     const Entry& entry = node_entry.second;
2777     const bool is_dead = !entry.has_value;
2778     const NodeItem* item = gview->node(node->id());
2779     EntryVector outputs{entry};
2780     ActivateNodes(item, is_dead, iter, &outputs, ready);
2781   }
2782 }
2783 
AddLoopInv(const NodeItem * item,const Entry & entry,TaggedNodeSeq * ready)2784 void ExecutorState::FrameState::AddLoopInv(const NodeItem* item,
2785                                            const Entry& entry,
2786                                            TaggedNodeSeq* ready) {
2787   // Store this value.
2788   inv_values.push_back({item->node, entry});
2789 
2790   // Make this value available to all iterations.
2791   const bool is_dead = !entry.has_value;
2792   for (int i = 0; i <= iteration_count; ++i) {
2793     EntryVector outputs{entry};
2794     ActivateNodes(item, is_dead, i, &outputs, ready);
2795   }
2796 }
2797 
IsIterationDone(int64 iter)2798 bool ExecutorState::FrameState::IsIterationDone(int64 iter) {
2799   IterationState* iter_state = GetIteration(iter);
2800   if (iter_state->outstanding_ops == 0 &&
2801       iter_state->outstanding_frame_count == 0) {
2802     if (iter == 0) {
2803       // The enclosing frame has no pending input.
2804       return num_pending_inputs == 0;
2805     } else {
2806       // The preceding iteration is deleted (and therefore done).
2807       return (GetIteration(iter - 1) == nullptr);
2808     }
2809   }
2810   return false;
2811 }
2812 
IncrementIteration(const GraphView * gview,TaggedNodeSeq * ready)2813 void ExecutorState::FrameState::IncrementIteration(const GraphView* gview,
2814                                                    TaggedNodeSeq* ready) {
2815   iteration_count++;
2816   const int64 next_iter = iteration_count;
2817 
2818   // Initialize the next iteration.
2819   IterationState* iter_state =
2820       new IterationState(pending_counts, total_input_tensors);
2821   SetIteration(next_iter, iter_state);
2822   num_outstanding_iterations++;
2823   dead_exits.clear();
2824 
2825   // Activate the successors of the deferred roots in the new iteration.
2826   ActivateNexts(gview, next_iter, ready);
2827 
2828   // Activate the loop invariants in the new iteration.
2829   ActivateLoopInvs(gview, next_iter, ready);
2830 }
2831 
CleanupIterations(const GraphView * gview,int64 iter,TaggedNodeSeq * ready)2832 bool ExecutorState::FrameState::CleanupIterations(const GraphView* gview,
2833                                                   int64 iter,
2834                                                   TaggedNodeSeq* ready) {
2835   int64 curr_iter = iter;
2836   while (curr_iter <= iteration_count && IsIterationDone(curr_iter)) {
2837     // Delete the iteration curr_iter.
2838     delete GetIteration(curr_iter);
2839     SetIteration(curr_iter, nullptr);
2840     --num_outstanding_iterations;
2841     ++curr_iter;
2842 
2843     // When one iteration is completed, we check for deferred iteration,
2844     // and start it if there is one.
2845     if (!next_iter_roots.empty()) {
2846       IncrementIteration(gview, ready);
2847     }
2848   }
2849   return IsFrameDone();
2850 }
2851 
RunAsync(const Args & args,DoneCallback done)2852 void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
2853   (new ExecutorState(args, this))->RunAsync(std::move(done));
2854 }
2855 
2856 }  // namespace
2857 
NewLocalExecutor(const LocalExecutorParams & params,std::unique_ptr<const Graph> graph,Executor ** executor)2858 Status NewLocalExecutor(const LocalExecutorParams& params,
2859                         std::unique_ptr<const Graph> graph,
2860                         Executor** executor) {
2861   ExecutorImpl* impl = new ExecutorImpl(params, std::move(graph));
2862   const Status s = impl->Initialize();
2863   if (s.ok()) {
2864     *executor = impl;
2865   } else {
2866     delete impl;
2867   }
2868   return s;
2869 }
2870 
CreateNonCachedKernel(Device * device,FunctionLibraryRuntime * flib,const NodeDef & ndef,int graph_def_version,OpKernel ** kernel)2871 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
2872                              const NodeDef& ndef, int graph_def_version,
2873                              OpKernel** kernel) {
2874   const auto device_type = DeviceType(device->attributes().device_type());
2875   auto allocator = device->GetAllocator(AllocatorAttributes());
2876   return CreateOpKernel(device_type, device, allocator, flib, ndef,
2877                         graph_def_version, kernel);
2878 }
2879 
DeleteNonCachedKernel(OpKernel * kernel)2880 void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }
2881 
2882 namespace {
2883 
2884 class DefaultExecutorRegistrar {
2885  public:
DefaultExecutorRegistrar()2886   DefaultExecutorRegistrar() {
2887     Factory* factory = new Factory;
2888     ExecutorFactory::Register("", factory);
2889     ExecutorFactory::Register("DEFAULT", factory);
2890   }
2891 
2892  private:
2893   class Factory : public ExecutorFactory {
NewExecutor(const LocalExecutorParams & params,std::unique_ptr<const Graph> graph,std::unique_ptr<Executor> * out_executor)2894     Status NewExecutor(const LocalExecutorParams& params,
2895                        std::unique_ptr<const Graph> graph,
2896                        std::unique_ptr<Executor>* out_executor) override {
2897       Executor* ret = nullptr;
2898       TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret));
2899       out_executor->reset(ret);
2900       return Status::OK();
2901     }
2902   };
2903 };
2904 static DefaultExecutorRegistrar registrar;
2905 
2906 }  // namespace
2907 
2908 }  // namespace tensorflow
2909