1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/executor.h"
17
18 #include <atomic>
19 #include <deque>
20 #include <memory>
21 #include <string>
22 #include <unordered_map>
23 #include <vector>
24
25 #include "tensorflow/core/common_runtime/costmodel_manager.h"
26 #include "tensorflow/core/common_runtime/executor_factory.h"
27 #include "tensorflow/core/common_runtime/pending_counts.h"
28 #include "tensorflow/core/common_runtime/step_stats_collector.h"
29 #include "tensorflow/core/framework/allocation_description.pb.h"
30 #include "tensorflow/core/framework/allocator.h"
31 #include "tensorflow/core/framework/cancellation.h"
32 #include "tensorflow/core/framework/collective.h"
33 #include "tensorflow/core/framework/control_flow.h"
34 #include "tensorflow/core/framework/device_attributes.pb.h"
35 #include "tensorflow/core/framework/graph.pb.h"
36 #include "tensorflow/core/framework/log_memory.h"
37 #include "tensorflow/core/framework/node_def_util.h"
38 #include "tensorflow/core/framework/op_kernel.h"
39 #include "tensorflow/core/framework/op_segment.h"
40 #include "tensorflow/core/framework/step_stats.pb.h"
41 #include "tensorflow/core/framework/tensor.h"
42 #include "tensorflow/core/framework/tensor_reference.h"
43 #include "tensorflow/core/framework/types.h"
44 #include "tensorflow/core/framework/types.pb.h"
45 #include "tensorflow/core/graph/edgeset.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/core/notification.h"
48 #include "tensorflow/core/lib/core/stringpiece.h"
49 #include "tensorflow/core/lib/core/threadpool.h"
50 #include "tensorflow/core/lib/gtl/flatmap.h"
51 #include "tensorflow/core/lib/gtl/flatset.h"
52 #include "tensorflow/core/lib/gtl/inlined_vector.h"
53 #include "tensorflow/core/lib/gtl/manual_constructor.h"
54 #include "tensorflow/core/lib/gtl/stl_util.h"
55 #include "tensorflow/core/lib/hash/hash.h"
56 #include "tensorflow/core/lib/strings/str_util.h"
57 #include "tensorflow/core/lib/strings/stringprintf.h"
58 #include "tensorflow/core/platform/context.h"
59 #include "tensorflow/core/platform/env.h"
60 #include "tensorflow/core/platform/logging.h"
61 #include "tensorflow/core/platform/macros.h"
62 #include "tensorflow/core/platform/mutex.h"
63 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
64 #include "tensorflow/core/platform/thread_annotations.h"
65 #include "tensorflow/core/platform/tracing.h"
66 #include "tensorflow/core/platform/types.h"
67 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
68
69 namespace tensorflow {
70 namespace {
71
// Shared default-constructed Tensor (1-D, 0 elements), reachable only through
// this const pointer. It is allocated once with `new`; nothing in the visible
// part of this file deletes it.
static const Tensor* const kEmptyTensor = new Tensor;
74
IsInitializationOp(const Node * node)75 bool IsInitializationOp(const Node* node) {
76 return node->op_def().allows_uninitialized_input();
77 }
78
79 // Helper routines for collecting step stats.
80 namespace nodestats {
NowInNsec()81 inline int64 NowInNsec() { return Env::Default()->NowNanos(); }
82
SetScheduled(NodeExecStatsInterface * stats,int64 micros)83 void SetScheduled(NodeExecStatsInterface* stats, int64 micros) {
84 if (!stats) return;
85 stats->SetScheduled(micros * EnvTime::kMicrosToNanos);
86 }
87
SetAllStart(NodeExecStatsInterface * stats)88 void SetAllStart(NodeExecStatsInterface* stats) {
89 if (!stats) return;
90 stats->RecordExecutorStarted();
91 }
92
SetOpStart(NodeExecStatsInterface * stats)93 void SetOpStart(NodeExecStatsInterface* stats) {
94 if (!stats) return;
95 stats->RecordComputeStarted();
96 }
97
SetOpEnd(NodeExecStatsInterface * stats)98 void SetOpEnd(NodeExecStatsInterface* stats) {
99 if (!stats) return;
100 stats->RecordComputeEnded();
101 }
102
SetAllEnd(NodeExecStatsInterface * stats)103 void SetAllEnd(NodeExecStatsInterface* stats) {
104 if (!stats) return;
105 stats->RecordExecutorEnded();
106 }
107
SetOutput(NodeExecStatsInterface * stats,int slot,const Tensor * v)108 void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) {
109 if (!stats) return;
110 stats->SetOutput(slot, v);
111 }
112
SetMemory(NodeExecStatsInterface * stats,OpKernelContext * ctx)113 void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
114 if (!stats) return;
115 stats->SetMemory(ctx);
116 }
117
SetReferencedTensors(NodeExecStatsInterface * stats,const TensorReferenceVector & tensors)118 void SetReferencedTensors(NodeExecStatsInterface* stats,
119 const TensorReferenceVector& tensors) {
120 if (!stats) return;
121 stats->SetReferencedTensors(tensors);
122 }
123
124 } // namespace nodestats
125
126 class ExecutorImpl;
127 class GraphView;
128
// Compact description of one outgoing edge of a node. An array of these is
// stored in the variable-length section that follows each NodeItem.
struct EdgeInfo {
  // Id of the node at the destination end of the edge.
  int dst_id;
  // Output slot of the source node that this edge carries (31-bit bit-field).
  int output_slot : 31;
  // true if this is the last info for output_slot in the EdgeInfo list.
  bool is_last : 1;
  // Input slot of the destination node that this edge feeds.
  int input_slot;
};
136
137 // Time the execution of kernels (in CPU cycles). Used to dynamically identify
138 // inexpensive kernels which can be dispatched inline.
139 struct KernelTimer {
140 uint64 start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle();
141
ElapsedCyclestensorflow::__anon6f8fc96b0111::KernelTimer142 uint64 ElapsedCycles() {
143 return profile_utils::CpuUtils::GetCurrentClockCycle() - start_cycles;
144 }
145 };
146
147 struct NodeItem {
NodeItemtensorflow::__anon6f8fc96b0111::NodeItem148 NodeItem() {}
149
150 // A graph node.
151 const Node* node = nullptr;
152
153 // The kernel for this node.
154 OpKernel* kernel = nullptr;
155
156 bool kernel_is_async : 1; // True iff kernel->AsAsync() != nullptr
157 bool is_merge : 1; // True iff IsMerge(node)
158 bool is_enter : 1; // True iff IsEnter(node)
159 bool is_constant_enter : 1; // True iff IsEnter(node) and
160 // node->GetAttr("is_constant") == true.
161 bool is_exit : 1; // True iff IsExit(node)
162 bool is_control_trigger : 1; // True iff IsControlTrigger(node)
163 bool is_sink : 1; // True iff IsSink(node)
164 // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node)
165 bool is_enter_exit_or_next_iter : 1;
166
167 // Cached values of node->num_inputs() and node->num_outputs(), to
168 // avoid levels of indirection.
169 int num_inputs;
170 int num_outputs;
171
172 // ExecutorImpl::tensors_[input_start] is the 1st positional input
173 // for this node.
174 int input_start = 0;
175
176 // Number of output edges.
177 size_t num_output_edges;
178
179 PendingCounts::Handle pending_id;
180
output_edge_listtensorflow::__anon6f8fc96b0111::NodeItem181 const EdgeInfo* output_edge_list() const { return output_edge_base(); }
182
183 // ith output edge.
output_edgetensorflow::__anon6f8fc96b0111::NodeItem184 const EdgeInfo& output_edge(int i) const {
185 DCHECK_GE(i, 0);
186 DCHECK_LT(i, num_output_edges);
187 return output_edge_base()[i];
188 }
189
input_typetensorflow::__anon6f8fc96b0111::NodeItem190 DataType input_type(int i) const {
191 DCHECK_LT(i, num_inputs);
192 return static_cast<DataType>(input_type_base()[i]);
193 }
output_typetensorflow::__anon6f8fc96b0111::NodeItem194 DataType output_type(int i) const {
195 DCHECK_LT(i, num_outputs);
196 return static_cast<DataType>(output_type_base()[i]);
197 }
198
199 // Return array of per-output allocator attributes.
output_attrstensorflow::__anon6f8fc96b0111::NodeItem200 const AllocatorAttributes* output_attrs() const { return output_attr_base(); }
201
202 // Return array of expected input index from which each output should
203 // be forwarded:
204 // kNeverForward (-2) for DO NOT FORWARD (must allocate).
205 // kNoReservation (-1) for no expected forwarding.
206 // 0... for forward from that input.
forward_fromtensorflow::__anon6f8fc96b0111::NodeItem207 const int* forward_from() const { return forward_from_base(); }
208
209 private:
210 friend class GraphView;
211
212 // Variable length section starts immediately after *this
213 // (uint8 is enough for DataType).
214 // EdgeInfo out_edges[num_out_edges];
215 // AllocatorAttributes output_attr[num_outputs];
216 // int forward_from[num_outputs];
217 // uint8 input_type[num_inputs];
218 // uint8 output_type[num_outputs];
219
220 // Return pointer to variable length section.
vartensorflow::__anon6f8fc96b0111::NodeItem221 char* var() const {
222 return const_cast<char*>(reinterpret_cast<const char*>(this) +
223 sizeof(NodeItem));
224 }
225
output_edge_basetensorflow::__anon6f8fc96b0111::NodeItem226 EdgeInfo* output_edge_base() const {
227 return reinterpret_cast<EdgeInfo*>(var());
228 }
output_attr_basetensorflow::__anon6f8fc96b0111::NodeItem229 AllocatorAttributes* output_attr_base() const {
230 return reinterpret_cast<AllocatorAttributes*>(var() + sizeof(EdgeInfo) *
231 num_output_edges);
232 }
forward_from_basetensorflow::__anon6f8fc96b0111::NodeItem233 int* forward_from_base() const {
234 return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges +
235 sizeof(AllocatorAttributes) * num_outputs);
236 }
input_type_basetensorflow::__anon6f8fc96b0111::NodeItem237 uint8* input_type_base() const {
238 return reinterpret_cast<uint8*>(
239 var() + sizeof(EdgeInfo) * num_output_edges +
240 sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs);
241 }
output_type_basetensorflow::__anon6f8fc96b0111::NodeItem242 uint8* output_type_base() const {
243 return reinterpret_cast<uint8*>(
244 var() + sizeof(EdgeInfo) * num_output_edges +
245 sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs +
246 sizeof(uint8) * num_inputs);
247 }
248
249 TF_DISALLOW_COPY_AND_ASSIGN(NodeItem);
250 };
251
252 typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
253 typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec;
254 typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
255
256 // Immutable view of a Graph organized for efficient execution.
257 class GraphView {
258 public:
GraphView()259 GraphView() : space_(nullptr) {}
260 ~GraphView();
261
262 void Initialize(const Graph* g);
263 Status SetAllocAttrs(const Graph* g, const Device* device);
264 void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes);
265
node(size_t id) const266 NodeItem* node(size_t id) const {
267 DCHECK_GE(id, 0);
268 DCHECK_LT(id, num_nodes_);
269 uint32 offset = node_offsets_[id];
270 return ((offset == kuint32max)
271 ? nullptr
272 : reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
273 }
274
275 private:
276 char* InitializeNode(char* ptr, const Node* n);
277 size_t NodeItemBytes(const Node* n);
278
279 int32 num_nodes_ = 0;
280 uint32* node_offsets_ = nullptr; // array of size "graph_.num_node_ids()"
281 // node_offsets_[id] holds the byte offset for node w/ "id" in space_
282
283 char* space_; // NodeItem objects are allocated here
284
285 TF_DISALLOW_COPY_AND_ASSIGN(GraphView);
286 };
287
288 class ExecutorImpl : public Executor {
289 public:
ExecutorImpl(const LocalExecutorParams & p,std::unique_ptr<const Graph> g)290 ExecutorImpl(const LocalExecutorParams& p, std::unique_ptr<const Graph> g)
291 : params_(p), graph_(std::move(g)), gview_() {
292 CHECK(p.create_kernel != nullptr);
293 CHECK(p.delete_kernel != nullptr);
294 }
295
~ExecutorImpl()296 ~ExecutorImpl() override {
297 for (int i = 0; i < graph_->num_node_ids(); i++) {
298 NodeItem* item = gview_.node(i);
299 if (item != nullptr) {
300 params_.delete_kernel(item->kernel);
301 }
302 }
303 for (auto fiter : frame_info_) {
304 delete fiter.second;
305 }
306 }
307
308 Status Initialize();
309
310 // Process all Nodes in the current graph, attempting to infer the
311 // memory allocation attributes to be used wherever they may allocate
312 // a tensor buffer.
313 Status SetAllocAttrs();
314
315 void RunAsync(const Args& args, DoneCallback done) override;
316
317 private:
318 friend class ExecutorState;
319
320 struct ControlFlowInfo {
321 gtl::FlatSet<string> unique_frame_names;
322 std::vector<string> frame_names;
323 };
324
325 struct FrameInfo {
FrameInfotensorflow::__anon6f8fc96b0111::ExecutorImpl::FrameInfo326 FrameInfo()
327 : input_count(0),
328 total_inputs(0),
329 pending_counts(nullptr),
330 nodes(nullptr) {}
331
332 // The total number of inputs to a frame.
333 int input_count;
334
335 // The total number of input tensors of a frame.
336 // == sum(nodes[*].num_inputs()) where nodes are the nodes in the frame.
337 int total_inputs;
338
339 // Used to determine the next place to allocate space in the
340 // pending_counts data structure we'll eventually construct
341 PendingCounts::Layout pending_counts_layout;
342
343 // Each frame has its own PendingCounts only for the nodes in the frame.
344 PendingCounts* pending_counts; // Owned
345
346 // The nodes in a frame. Used only for debugging.
347 std::vector<const Node*>* nodes; // Owned
348
~FrameInfotensorflow::__anon6f8fc96b0111::ExecutorImpl::FrameInfo349 ~FrameInfo() {
350 delete pending_counts;
351 delete nodes;
352 }
353 };
354
355 static Status BuildControlFlowInfo(const Graph* graph,
356 ControlFlowInfo* cf_info);
357 void InitializePending(const Graph* graph, const ControlFlowInfo& cf_info);
358
EnsureFrameInfo(const string & fname)359 FrameInfo* EnsureFrameInfo(const string& fname) {
360 auto slot = &frame_info_[fname];
361 if (*slot == nullptr) {
362 *slot = new FrameInfo;
363 }
364 return *slot;
365 }
366
367 // Owned.
368 LocalExecutorParams params_;
369 std::unique_ptr<const Graph> graph_;
370 GraphView gview_;
371
372 // A cached value of params_
373 bool device_record_tensor_accesses_ = false;
374
375 // Root nodes (with no in edges) that should form the initial ready queue
376 std::vector<const Node*> root_nodes_;
377
378 // Mapping from frame name to static information about the frame.
379 // TODO(yuanbyu): We could cache it along with the graph so to avoid
380 // the overhead of constructing it for each executor instance.
381 gtl::FlatMap<string, FrameInfo*> frame_info_;
382
383 TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
384 };
385
386 // Infer memory allocation attributes of a node n's output,
387 // based on its use node dst. Note that dst might not be directly
388 // connected to n by a single edge, but might be a downstream
389 // consumer of n's output by reference. *attr is updated with any
390 // necessary attributes.
391 Status InferAllocAttr(const Node* n, const Node* dst,
392 const DeviceNameUtils::ParsedName& local_dev_name,
393 AllocatorAttributes* attr);
394
~GraphView()395 GraphView::~GraphView() {
396 static_assert(std::is_trivially_destructible<AllocatorAttributes>::value,
397 "Update code if AllocatorAttributes gains a destructor");
398 static_assert(std::is_trivially_destructible<EdgeInfo>::value,
399 "Update code if EdgeInfo gains a destructor");
400 for (int i = 0; i < num_nodes_; i++) {
401 NodeItem* n = node(i);
402 if (n != nullptr) {
403 n->NodeItem::~NodeItem();
404 // Memory for "n" itself is held in space_ & gets cleaned up below
405 }
406 }
407 delete[] node_offsets_;
408 delete[] space_;
409 }
410
NodeItemBytes(const Node * n)411 size_t GraphView::NodeItemBytes(const Node* n) {
412 const size_t num_output_edges = n->out_edges().size();
413 const int num_inputs = n->num_inputs();
414 const int num_outputs = n->num_outputs();
415
416 // Compute number of bytes needed for NodeItem and variable length data.
417 // We do not subtract sizeof(var) since num_inputs/num_outputs might
418 // both be zero.
419 const size_t raw_bytes =
420 sizeof(NodeItem) // Fixed
421 + num_output_edges * sizeof(EdgeInfo) // output_edges[...]
422 + num_outputs * sizeof(AllocatorAttributes) // output_attr[...]
423 + num_outputs * sizeof(int) // forward_from[num_outputs]
424 + num_inputs * sizeof(uint8) // input_type[num_inputs]
425 + num_outputs * sizeof(uint8); // output_type[num_outputs]
426 static constexpr size_t kItemAlignment = sizeof(NodeItem*);
427 static_assert(kItemAlignment % alignof(NodeItem) == 0,
428 "NodeItem must be aligned with kItemAlignment");
429 static_assert(kItemAlignment % alignof(EdgeInfo) == 0,
430 "EdgeInfo must be aligned with kItemAlignment");
431 static_assert(kItemAlignment % alignof(AllocatorAttributes) == 0,
432 "AllocatorAttributes must be aligned with kItemAlignment");
433 static_assert(sizeof(NodeItem) % alignof(EdgeInfo) == 0,
434 "NodeItem must be aligned with EdgeInfo");
435 static_assert(sizeof(NodeItem) % alignof(AllocatorAttributes) == 0,
436 "NodeItem must be aligned with AllocatorAttributes");
437 static_assert(sizeof(EdgeInfo) % alignof(AllocatorAttributes) == 0,
438 "EdgeInfo must be aligned with AllocatorAttributes");
439 const size_t bytes =
440 ((raw_bytes + kItemAlignment - 1) / kItemAlignment) * kItemAlignment;
441 return bytes;
442 }
443
InitializeNode(char * ptr,const Node * n)444 char* GraphView::InitializeNode(char* ptr, const Node* n) {
445 const int id = n->id();
446 CHECK(node_offsets_[id] == kuint32max); // Initial value in constructor
447
448 const size_t bytes = NodeItemBytes(n);
449 constexpr size_t kItemAlignment = sizeof(NodeItem*);
450 CHECK_EQ(reinterpret_cast<uintptr_t>(ptr) % kItemAlignment, 0);
451 NodeItem* item = reinterpret_cast<NodeItem*>(ptr);
452
453 // We store a 32-bit offset relative to the beginning of space_, so that we
454 // only need an array of 32-bit values to map from node id to the NodeItem*,
455 // (versus 64 bits on most machines if we just stored an array of NodeItem*
456 // pointers). Casting to int64 is needed on 32bit CPU to avoid comparing
457 // values as "int" vs "size_t" in CHECK_LE.
458 CHECK_LE(static_cast<int64>(ptr - space_), kuint32max);
459 const uint32 offset = static_cast<uint32>(ptr - space_);
460 node_offsets_[id] = offset;
461 ptr += bytes;
462
463 const size_t num_output_edges = n->out_edges().size();
464 const int num_inputs = n->num_inputs();
465 const int num_outputs = n->num_outputs();
466
467 new (item) NodeItem();
468 item->num_inputs = num_inputs;
469 item->num_outputs = num_outputs;
470 item->num_output_edges = num_output_edges;
471
472 // Fill output edges.
473 // Keep track of the last EdgeInfo in the EdgeInfo array that references
474 // a given output slot. For all but the last, we need to do a copy of the
475 // Tensor when propagating results downstream in the graph, but for the
476 // last one, we can just do a move of the Tensor object to propagate it.
477 gtl::InlinedVector<EdgeInfo*, 4> last_indices(num_outputs, nullptr);
478 EdgeInfo* dst_edge = item->output_edge_base();
479 for (auto e : n->out_edges()) {
480 dst_edge->dst_id = e->dst()->id();
481 CHECK_LE(e->src_output(), 0x3FFFFFFF); // Must fit in 31 bits
482 dst_edge->output_slot = e->src_output();
483 dst_edge->is_last = false;
484 const int output_slot = dst_edge->output_slot;
485 if (output_slot >= 0) {
486 last_indices[output_slot] = dst_edge;
487 }
488 dst_edge->input_slot = e->dst_input();
489 dst_edge++;
490 }
491 for (EdgeInfo* edge_info : last_indices) {
492 if (edge_info != nullptr) {
493 edge_info->is_last = true;
494 }
495 }
496
497 AllocatorAttributes* output_attrs = item->output_attr_base();
498 for (int i = 0; i < num_outputs; i++) {
499 new (&output_attrs[i]) AllocatorAttributes();
500 }
501
502 DCHECK_LT(DataType_MAX, 255); // Must fit in uint8
503 uint8* input_types = item->input_type_base();
504 for (int i = 0; i < num_inputs; i++) {
505 input_types[i] = static_cast<uint8>(n->input_type(i));
506 DCHECK_EQ(item->input_type(i), n->input_type(i));
507 }
508
509 // Check ScopedAllocatorAttrs and forward_from. Also assign output_types.
510 {
511 std::vector<int> forward_input;
512 Status fwd_status =
513 GetNodeAttr(n->attrs(), "_forward_input", &forward_input);
514 std::vector<int> scoped_allocator_attrs;
515 Status sa_status =
516 GetNodeAttr(n->attrs(), "_scoped_allocator", &scoped_allocator_attrs);
517
518 int* forward_from = item->forward_from_base();
519 uint8* output_types = item->output_type_base();
520 for (int i = 0; i < num_outputs; ++i) {
521 output_types[i] = static_cast<uint8>(n->output_type(i));
522 DCHECK_EQ(item->output_type(i), n->output_type(i));
523
524 forward_from[i] = OpKernelContext::Params::kNoReservation;
525 if (sa_status.ok()) {
526 for (int j = 0; j < scoped_allocator_attrs.size(); j += 2) {
527 if (scoped_allocator_attrs[j] == i) {
528 // This output slot must be explicitly allocated from a
529 // ScopedAllocator.
530 forward_from[i] = OpKernelContext::Params::kNeverForward;
531 DCHECK_EQ(output_attrs[i].scope_id, 0);
532 output_attrs[i].scope_id = scoped_allocator_attrs[j + 1];
533 }
534 }
535 }
536 if (fwd_status.ok() &&
537 forward_from[i] == OpKernelContext::Params::kNoReservation) {
538 DCHECK_EQ(forward_input.size() % 2, 0);
539 for (int j = 0; j < forward_input.size(); j += 2) {
540 if (forward_input[j + 1] == i) {
541 DCHECK_EQ(forward_from[i], OpKernelContext::Params::kNoReservation);
542 forward_from[i] = forward_input[j];
543 break;
544 }
545 }
546 }
547 }
548 }
549
550 return ptr;
551 }
552
Initialize(const Graph * g)553 void GraphView::Initialize(const Graph* g) {
554 CHECK(node_offsets_ == nullptr);
555 const int num_nodes = g->num_node_ids();
556 num_nodes_ = num_nodes;
557 size_t total_bytes = 0;
558 for (const Node* n : g->nodes()) {
559 total_bytes += NodeItemBytes(n);
560 }
561
562 node_offsets_ = new uint32[num_nodes];
563 for (int i = 0; i < num_nodes; i++) {
564 node_offsets_[i] = kuint32max;
565 }
566
567 space_ = new char[total_bytes]; // NodeItem objects are allocated here
568 char* ptr = space_;
569 for (const Node* n : g->nodes()) {
570 ptr = InitializeNode(ptr, n);
571 }
572 CHECK_EQ(ptr, space_ + total_bytes);
573 }
574
GetMaxPendingCounts(const Node * n,size_t * max_pending,size_t * max_dead_count)575 void GetMaxPendingCounts(const Node* n, size_t* max_pending,
576 size_t* max_dead_count) {
577 const size_t num_in_edges = n->in_edges().size();
578 size_t initial_count;
579 if (IsMerge(n)) {
580 // merge waits all control inputs so we initialize the pending
581 // count to be the number of control edges.
582 int32 num_control_edges = 0;
583 for (const Edge* edge : n->in_edges()) {
584 if (edge->IsControlEdge()) {
585 num_control_edges++;
586 }
587 }
588 // Use bit 0 to indicate if we are waiting for a ready live data input.
589 initial_count = 1 + (num_control_edges << 1);
590 } else {
591 initial_count = num_in_edges;
592 }
593
594 *max_pending = initial_count;
595 *max_dead_count = num_in_edges;
596 }
597
// Prepares the executor for running its graph:
//   1. builds the GraphView (one NodeItem per node),
//   2. computes the frame each node belongs to,
//   3. creates a kernel for every node and caches its control-flow flags,
//   4. reserves a pending-counts handle per node in its frame's layout,
//   5. initializes the pending counts and the allocator attributes.
// Returns the first error encountered; when kernel creation fails the error
// is annotated with the node definition and logged before being returned.
Status ExecutorImpl::Initialize() {
  gview_.Initialize(graph_.get());

  // Build the information about frames in this subgraph.
  ControlFlowInfo cf_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &cf_info));

  // Cache this value so we make this virtual function call once, rather
  // than O(# steps * # nodes per step) times.
  device_record_tensor_accesses_ =
      params_.device->RequiresRecordingAccessedTensors();

  // Create the FrameInfo (and its node list) for every frame up front, so
  // that the per-node loop below can append to `nodes` unconditionally.
  for (auto& it : cf_info.unique_frame_names) {
    EnsureFrameInfo(it)->nodes = new std::vector<const Node*>;
  }

  // Preprocess every node in the graph to create an instance of op
  // kernel for each node.
  for (const Node* n : graph_->nodes()) {
    const int id = n->id();
    const string& frame_name = cf_info.frame_names[id];
    FrameInfo* frame_info = EnsureFrameInfo(frame_name);

    // See if this node is a root node, and if so, add to root_nodes_.
    if (n->in_edges().empty()) {
      root_nodes_.push_back(n);
    }

    NodeItem* item = gview_.node(id);
    item->node = n;

    // This node's inputs occupy the next num_inputs() slots of its frame.
    item->input_start = frame_info->total_inputs;
    frame_info->total_inputs += n->num_inputs();

    Status s = params_.create_kernel(n->def(), &item->kernel);
    if (!s.ok()) {
      // Reset the kernel pointer so the item does not keep whatever
      // create_kernel may have left in it.
      item->kernel = nullptr;
      s = AttachDef(s, *n);
      LOG(ERROR) << "Executor failed to create kernel. " << s;
      return s;
    }
    CHECK(item->kernel);
    // Cache the node-kind predicates on the item.
    item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
    item->is_merge = IsMerge(n);
    item->is_enter = IsEnter(n);
    if (item->is_enter) {
      bool is_constant_enter;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
      item->is_constant_enter = is_constant_enter;
    } else {
      item->is_constant_enter = false;
    }
    item->is_exit = IsExit(n);
    item->is_control_trigger = IsControlTrigger(n);
    item->is_sink = IsSink(n);
    item->is_enter_exit_or_next_iter =
        (IsEnter(n) || IsExit(n) || IsNextIteration(n));

    // Compute the maximum values we'll store for this node in the
    // pending counts data structure, and allocate a handle in
    // that frame's pending counts data structure that has enough
    // space to store these maximal count values.
    size_t max_pending, max_dead;
    GetMaxPendingCounts(n, &max_pending, &max_dead);
    item->pending_id =
        frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead);

    // Initialize static information about the frames in the graph.
    frame_info->nodes->push_back(n);
    if (IsEnter(n)) {
      // An Enter node counts as one input of the frame named by its
      // "frame_name" attribute.
      string enter_name;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name));
      EnsureFrameInfo(enter_name)->input_count++;
    }
  }

  // Initialize PendingCounts only after item->pending_id is initialized for
  // all nodes.
  InitializePending(graph_.get(), cf_info);

  return gview_.SetAllocAttrs(graph_.get(), params_.device);
}
681
682 // If a Node has been marked to use a ScopedAllocator x for output i, then
683 // sc_attr will contain the subsequence (i, x) at an even offset. This function
684 // extracts and transfers that ScopedAllocator id to alloc_attr. For now, we
685 // only allow one ScopedAllocator use per Node.
ExtractScopedAllocatorAttr(const std::vector<int> & sc_attr,int output_index,AllocatorAttributes * alloc_attr)686 bool ExtractScopedAllocatorAttr(const std::vector<int>& sc_attr,
687 int output_index,
688 AllocatorAttributes* alloc_attr) {
689 DCHECK_LE(2, sc_attr.size());
690 for (int i = 0; i < sc_attr.size(); i += 2) {
691 if (sc_attr[i] == output_index) {
692 CHECK_EQ(alloc_attr->scope_id, 0);
693 alloc_attr->scope_id = sc_attr[i + 1];
694 return true;
695 }
696 }
697 return false;
698 }
699
SetScopedAllocatorAttrs(const std::vector<const Node * > & sa_nodes)700 void GraphView::SetScopedAllocatorAttrs(
701 const std::vector<const Node*>& sa_nodes) {
702 for (const Node* sa : sa_nodes) {
703 NodeItem* sa_item = node(sa->id());
704 AllocatorAttributes* sa_attrs = sa_item->output_attr_base();
705 // Control edges out of the ScopedAllocator should be use instances, but may
706 // include a few other nodes.
707 for (const auto& e : sa->out_edges()) {
708 if (!e->IsControlEdge()) {
709 continue;
710 }
711 Node* use_node = e->dst();
712 NodeItem* item = node(use_node->id());
713 AllocatorAttributes* use_attrs = item->output_attr_base();
714 std::vector<int> scoped_allocator_attrs;
715 Status s = GetNodeAttr(use_node->attrs(), "_scoped_allocator",
716 &scoped_allocator_attrs);
717 if (!s.ok()) {
718 VLOG(2) << "Failed to find expected ScopedAllocator attr on "
719 << use_node->name();
720 continue;
721 }
722 // There can be more than one output using ScopedAllocation, but this
723 // analysis assumes they use the same ScopedAllocator.
724 for (const auto& e : use_node->out_edges()) {
725 if (!e->IsControlEdge()) {
726 AllocatorAttributes attr;
727 if (ExtractScopedAllocatorAttr(scoped_allocator_attrs,
728 e->src_output(), &attr)) {
729 // Set the scope_id on this use instance node.
730 (use_attrs + e->src_output())->Merge(attr);
731 // Propagate the other attributes of this node back to the SA node.
732 attr = *(use_attrs + e->src_output());
733 attr.scope_id = 0;
734 sa_attrs->Merge(attr);
735 }
736 }
737 }
738 }
739 }
740 }
741
SetAllocAttrs(const Graph * g,const Device * device)742 Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) {
743 Status s;
744 DeviceNameUtils::ParsedName local_dev_name = device->parsed_name();
745
746 std::vector<const Node*> scoped_allocator_instances;
747 for (const Node* n : g->nodes()) {
748 NodeItem* item = node(n->id());
749 AllocatorAttributes* attrs = item->output_attr_base();
750 if (IsScopedAllocator(n)) {
751 scoped_allocator_instances.push_back(n);
752 }
753
754 // Examine the out edges of each node looking for special use
755 // cases that may affect memory allocation attributes.
756 for (const auto& e : n->out_edges()) {
757 if (!e->IsControlEdge()) {
758 AllocatorAttributes attr;
759 s = InferAllocAttr(n, e->dst(), local_dev_name, &attr);
760 if (!s.ok()) return s;
761 if (attr.value != 0 || attr.scope_id != 0) {
762 attrs[e->src_output()].Merge(attr);
763 }
764 }
765 }
766
767 for (int out = 0; out < n->num_outputs(); out++) {
768 const OpKernel* op_kernel = item->kernel;
769 DCHECK_LT(out, op_kernel->output_memory_types().size());
770 bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY;
771 if (on_host) {
772 AllocatorAttributes h;
773 h.set_on_host(on_host);
774 attrs[out].Merge(h);
775 }
776 }
777 }
778 SetScopedAllocatorAttrs(scoped_allocator_instances);
779 return s;
780 }
781
InferAllocAttr(const Node * n,const Node * dst,const DeviceNameUtils::ParsedName & local_dev_name,AllocatorAttributes * attr)782 Status InferAllocAttr(const Node* n, const Node* dst,
783 const DeviceNameUtils::ParsedName& local_dev_name,
784 AllocatorAttributes* attr) {
785 Status s;
786 // Note that it's possible for *n to be a Recv and *dst to be a Send,
787 // so these two cases are not mutually exclusive.
788 if (IsRecv(n)) {
789 string src_name;
790 s = GetNodeAttr(n->attrs(), "send_device", &src_name);
791 if (!s.ok()) return s;
792 DeviceNameUtils::ParsedName parsed_src_name;
793 if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) {
794 s = errors::Internal("Bad send_device attr '", src_name, "' in node ",
795 n->name());
796 return s;
797 }
798 if (!DeviceNameUtils::IsSameAddressSpace(parsed_src_name, local_dev_name)) {
799 // Value is going to be the sink of an RPC.
800 attr->set_nic_compatible(true);
801 VLOG(2) << "node " << n->name() << " is the sink of an RPC in";
802 } else if ((local_dev_name.type == "CPU" || n->IsHostRecv()) &&
803 parsed_src_name.type != "CPU") {
804 // Value is going to be the sink of a local DMA from GPU to CPU (or
805 // other types of accelerators).
806 attr->set_gpu_compatible(true);
807 VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy";
808 } else {
809 VLOG(2) << "default alloc case local type " << local_dev_name.type
810 << " remote type " << parsed_src_name.type;
811 }
812 }
813 if (IsSend(dst)) {
814 string dst_name;
815 s = GetNodeAttr(dst->attrs(), "recv_device", &dst_name);
816 if (!s.ok()) return s;
817 DeviceNameUtils::ParsedName parsed_dst_name;
818 if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) {
819 s = errors::Internal("Bad recv_device attr '", dst_name, "' in node ",
820 n->name());
821 return s;
822 }
823 if (!DeviceNameUtils::IsSameAddressSpace(parsed_dst_name, local_dev_name)) {
824 // Value is going to be the source of an RPC.
825 attr->set_nic_compatible(true);
826 VLOG(2) << "node " << n->name() << " is the source of an RPC out";
827 } else if ((local_dev_name.type == "CPU" || dst->IsHostSend()) &&
828 parsed_dst_name.type != "CPU") {
829 // Value is going to be the source of a local DMA from CPU to GPU (or
830 // other types of accelerators).
831 // Note that this does not cover the case where the allocation of the
832 // output tensor is not generated by the src: n.
833 attr->set_gpu_compatible(true);
834 VLOG(2) << "node " << n->name() << " is the source of a cpu->gpu copy";
835 } else {
836 VLOG(2) << "default alloc case local type " << local_dev_name.type
837 << " remote type " << parsed_dst_name.type;
838 }
839 }
840 if (n->IsCollective()) {
841 // We'll make the sweeping assumption that any collective op is going
842 // to be involved in network i/o.
843 attr->set_nic_compatible(true);
844 }
845 return s;
846 }
847
848 // The state associated with one invocation of ExecutorImpl::Run.
// ExecutorState dispatches nodes when they become ready and keeps
// track of how many predecessors of a node have not yet finished (pending_).
851 class ExecutorState {
852 public:
853 ExecutorState(const Executor::Args& args, ExecutorImpl* impl);
854 ~ExecutorState();
855
856 void RunAsync(Executor::DoneCallback done);
857
858 private:
859 // Either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
860 // TODO(yuanbyu): A better way to do "has_value"?
861 struct Entry {
Entrytensorflow::__anon6f8fc96b0111::ExecutorState::Entry862 Entry() {}
Entrytensorflow::__anon6f8fc96b0111::ExecutorState::Entry863 Entry(const Entry& other)
864 : ref(other.ref),
865 ref_mu(other.ref_mu),
866 has_value(other.has_value),
867 val_field_is_set(other.val_field_is_set),
868 alloc_attr(other.alloc_attr),
869 device_context(other.device_context) {
870 if (val_field_is_set) {
871 val.Init(*other.val);
872 }
873 }
~Entrytensorflow::__anon6f8fc96b0111::ExecutorState::Entry874 ~Entry() {
875 if (val_field_is_set) val.Destroy();
876 }
877
operator =tensorflow::__anon6f8fc96b0111::ExecutorState::Entry878 Entry& operator=(const Entry& other) {
879 if (val_field_is_set) {
880 val.Destroy();
881 }
882 ref = other.ref;
883 ref_mu = other.ref_mu;
884 has_value = other.has_value;
885 val_field_is_set = other.val_field_is_set;
886 alloc_attr = other.alloc_attr;
887 device_context = other.device_context;
888 if (val_field_is_set) {
889 val.Init(*other.val);
890 }
891 return *this;
892 }
893
operator =tensorflow::__anon6f8fc96b0111::ExecutorState::Entry894 Entry& operator=(Entry&& other) {
895 if (val_field_is_set) {
896 val.Destroy();
897 }
898 ref = other.ref;
899 ref_mu = other.ref_mu;
900 has_value = other.has_value;
901 val_field_is_set = other.val_field_is_set;
902 alloc_attr = other.alloc_attr;
903 device_context = other.device_context;
904 if (val_field_is_set) {
905 val.Init(std::move(*other.val));
906 }
907 return *this;
908 }
909
910 // Clears the <val> field.
ClearValtensorflow::__anon6f8fc96b0111::ExecutorState::Entry911 void ClearVal() {
912 if (val_field_is_set) {
913 val.Destroy();
914 val_field_is_set = false;
915 has_value = false;
916 }
917 }
918
919 // A tensor value, if val_field_is_set.
920 ManualConstructor<Tensor> val;
921
922 Tensor* ref = nullptr; // A tensor reference.
923 mutex* ref_mu = nullptr; // mutex for *ref if ref is not nullptr.
924
925 // Whether the value exists, either in <val> or <ref>.
926 bool has_value = false;
927
928 bool val_field_is_set = false;
929
930 // The attributes of the allocator that creates the tensor.
931 AllocatorAttributes alloc_attr;
932
933 // Every entry carries an optional DeviceContext containing
934 // Device-specific information about how the Tensor was produced.
935 DeviceContext* device_context = nullptr;
936 };
937
938 // Contains a value for [node->id()] for the device context assigned by the
939 // device at the beginning of a step.
940 DeviceContextMap device_context_map_;
941
942 struct TaggedNode;
943 typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
944 typedef gtl::InlinedVector<Entry, 4> EntryVector;
945
946 struct IterationState {
IterationStatetensorflow::__anon6f8fc96b0111::ExecutorState::IterationState947 explicit IterationState(const PendingCounts* pending_counts,
948 int total_input_tensors)
949 : input_tensors(new Entry[total_input_tensors]),
950 outstanding_ops(0),
951 outstanding_frame_count(0),
952 counts_(*pending_counts) { // Initialize with copy of *pending_counts
953 }
954
955 // The state of an iteration.
956
957 // One copy per iteration. For iteration k, i-th node's j-th input is in
958 // input_tensors[k][impl_->nodes[i].input_start + j]. An entry is either
959 // a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
960 //
961 // NOTE: No need to protect input_tensors[i] by any locks because it
962 // is resized once. Each element of tensors_ is written once by the
963 // source node of an edge and is cleared by the destination of the same
964 // edge. The latter node is never run concurrently with the former node.
965 Entry* input_tensors;
966
967 // The number of outstanding ops for each iteration.
968 size_t outstanding_ops;
969
970 // The number of outstanding frames for each iteration.
971 int outstanding_frame_count;
pendingtensorflow::__anon6f8fc96b0111::ExecutorState::IterationState972 int pending(PendingCounts::Handle h) { return counts_.pending(h); }
decrement_pendingtensorflow::__anon6f8fc96b0111::ExecutorState::IterationState973 int decrement_pending(PendingCounts::Handle h, int v) {
974 return counts_.decrement_pending(h, v);
975 }
976 // Mark a merge node as live
977 // REQUIRES: Node corresponding to "h" is a merge node
mark_livetensorflow::__anon6f8fc96b0111::ExecutorState::IterationState978 void mark_live(PendingCounts::Handle h) { counts_.mark_live(h); }
979 // Mark a node to show that processing has started.
mark_startedtensorflow::__anon6f8fc96b0111::ExecutorState::IterationState980 void mark_started(PendingCounts::Handle h) { counts_.mark_started(h); }
981 // Mark a node to show that processing has completed.
mark_completedtensorflow::__anon6f8fc96b0111::ExecutorState::IterationState982 void mark_completed(PendingCounts::Handle h) { counts_.mark_completed(h); }
node_statetensorflow::__anon6f8fc96b0111::ExecutorState::IterationState983 PendingCounts::NodeState node_state(PendingCounts::Handle h) {
984 return counts_.node_state(h);
985 }
986
dead_counttensorflow::__anon6f8fc96b0111::ExecutorState::IterationState987 int dead_count(PendingCounts::Handle h) { return counts_.dead_count(h); }
increment_dead_counttensorflow::__anon6f8fc96b0111::ExecutorState::IterationState988 void increment_dead_count(PendingCounts::Handle h) {
989 counts_.increment_dead_count(h);
990 }
adjust_for_activationtensorflow::__anon6f8fc96b0111::ExecutorState::IterationState991 void adjust_for_activation(PendingCounts::Handle h, bool increment_dead,
992 int* pending_result, int* dead_result) {
993 counts_.adjust_for_activation(h, increment_dead, pending_result,
994 dead_result);
995 }
996
~IterationStatetensorflow::__anon6f8fc96b0111::ExecutorState::IterationState997 ~IterationState() { delete[] input_tensors; }
998
999 private:
1000 PendingCounts counts_;
1001 };
1002
1003 struct FrameState {
FrameStatetensorflow::__anon6f8fc96b0111::ExecutorState::FrameState1004 explicit FrameState(const ExecutorImpl* impl, int parallel_iters)
1005 : executor(impl),
1006 max_parallel_iterations(parallel_iters),
1007 num_outstanding_iterations(1) {}
1008
1009 // A new frame is created for each loop. Execution starts at iteration 0.
1010 // When a value at iteration 0 passes through a NextIteration node,
1011 // iteration 1 is created and starts running. Note that iteration 0 may
1012 // still be running so multiple iterations may run in parallel. The
1013 // frame maintains the state of iterations in several data structures
1014 // such as pending_count and input_tensors. When iteration 0 completes,
1015 // we garbage collect the state of iteration 0.
1016 //
1017 // A frame instance is considered "done" and can be garbage collected
1018 // if all its inputs have entered and all its iterations are "done".
1019 //
1020 // A frame manages the live iterations of an iterative computation.
1021 // Iteration i is considered "done" when there are no outstanding ops,
1022 // frames at iteration i are done, all recvs for this iteration are
1023 // completed, and iteration i-1 is done. For iteration 0, we instead
1024 // wait for there to be no more pending inputs of the frame.
1025 //
1026 // Frames and iterations are garbage collected once they are done.
1027 // The state we need to keep around is highly dependent on the
1028 // parallelism enabled by the scheduler. We may want to have the
1029 // scheduler dynamically control the outstanding number of live
1030 // parallel frames and iterations. To reduce the state space, the
1031 // scheduler might want to schedule ops in inner frames first and
1032 // lower iterations first.
1033 //
1034 // This frame state is mostly initialized lazily on demand so we
1035 // don't introduce unnecessary overhead.
1036
1037 // The executor the frame is in.
1038 const ExecutorImpl* executor = nullptr;
1039
1040 // The name of this frame, which is the concatenation of its parent
1041 // frame name, the iteration of the parent frame when this frame was
1042 // created, and the value of the attr 'frame_name'.
1043 string frame_name;
1044
1045 // The unique id for this frame. Generated by fingerprinting
1046 // frame_name.
1047 uint64 frame_id;
1048
1049 // The iteration id of its parent frame when this frame is created.
1050 // -1 if there is no parent frame. The frame_name/parent_iter pair
1051 // uniquely identifies this FrameState.
1052 int64 parent_iter = -1;
1053
1054 // The FrameState of its parent frame.
1055 FrameState* parent_frame = nullptr;
1056
1057 // The maximum allowed number of parallel iterations.
1058 const int max_parallel_iterations;
1059
1060 // The number of inputs this frame is still waiting.
1061 int num_pending_inputs = 0;
1062
1063 // The highest iteration number we have reached so far in this frame.
1064 int64 iteration_count GUARDED_BY(mu) = 0;
1065
1066 // The number of outstanding iterations.
1067 int num_outstanding_iterations GUARDED_BY(mu) = 1;
1068
1069 // The active iteration states of this frame.
1070 gtl::InlinedVector<IterationState*, 12> iterations;
1071
1072 // The NextIteration nodes to enter a new iteration. If the number of
1073 // outstanding iterations reaches the limit, we will defer the start of
1074 // the next iteration until the number of outstanding iterations falls
1075 // below the limit.
1076 std::vector<std::pair<const Node*, Entry>> next_iter_roots GUARDED_BY(mu);
1077
1078 // The values of the loop invariants for this loop. They are added into
1079 // this list as they "enter" the frame. When a loop invariant enters,
1080 // we make it available to all active iterations. When the frame starts
1081 // a new iteration, we make all the current loop invariants available
1082 // to the new iteration.
1083 std::vector<std::pair<const Node*, Entry>> inv_values GUARDED_BY(mu);
1084
1085 // The list of dead exit nodes for the current highest iteration. We
1086 // will only "execute" the dead exits of the final iteration.
1087 std::vector<const Node*> dead_exits GUARDED_BY(mu);
1088
1089 // Static information specific to this frame.
1090 PendingCounts* pending_counts = nullptr;
1091 int total_input_tensors = 0;
1092 std::vector<const Node*>* nodes = nullptr;
1093
1094 // Lock ordering: ExecutorState.mu_ < mu;
1095 // during structured traversal: parent_frame->mu < mu.
1096 mutex mu;
1097
InitializeFrameInfotensorflow::__anon6f8fc96b0111::ExecutorState::FrameState1098 void InitializeFrameInfo(const string& enter_name) {
1099 auto it_frame_info = executor->frame_info_.find(enter_name);
1100 DCHECK(it_frame_info != executor->frame_info_.end());
1101 ExecutorImpl::FrameInfo* finfo = it_frame_info->second;
1102 pending_counts = finfo->pending_counts;
1103 total_input_tensors = finfo->total_inputs;
1104 num_pending_inputs = finfo->input_count;
1105 nodes = finfo->nodes;
1106 }
1107
GetIterationtensorflow::__anon6f8fc96b0111::ExecutorState::FrameState1108 inline IterationState* GetIteration(int64 iter)
1109 EXCLUSIVE_LOCKS_REQUIRED(mu) {
1110 size_t index = iter % iterations.size();
1111 return iterations[index];
1112 }
1113
SetIterationtensorflow::__anon6f8fc96b0111::ExecutorState::FrameState1114 inline void SetIteration(int64 iter, IterationState* state)
1115 EXCLUSIVE_LOCKS_REQUIRED(mu) {
1116 size_t index = iter % iterations.size();
1117 DCHECK(state == nullptr || iterations[index] == nullptr);
1118 iterations[index] = state;
1119 }
1120
1121 // Decrement the outstanding op count and clean up the iterations in the
1122 // frame. Return true iff the execution of the frame is done.
DecrementOutstandingOpstensorflow::__anon6f8fc96b0111::ExecutorState::FrameState1123 inline bool DecrementOutstandingOps(const GraphView* gview, int64 iter,
1124 TaggedNodeSeq* ready) {
1125 mutex_lock l(mu);
1126 return DecrementOutstandingOpsLocked(gview, iter, ready);
1127 }
1128
1129 // Decrement the outstanding op count and clean up the iterations in the
1130 // frame. Return true iff the execution of the frame is done.
DecrementOutstandingOpsLockedtensorflow::__anon6f8fc96b0111::ExecutorState::FrameState1131 inline bool DecrementOutstandingOpsLocked(const GraphView* gview,
1132 int64 iter, TaggedNodeSeq* ready)
1133 EXCLUSIVE_LOCKS_REQUIRED(mu) {
1134 IterationState* istate = GetIteration(iter);
1135 istate->outstanding_ops--;
1136 if (istate->outstanding_ops != 0) {
1137 return false;
1138 } else {
1139 return CleanupIterations(gview, iter, ready);
1140 }
1141 }
1142
1143 // Returns true if the computation in the frame is completed.
IsFrameDonetensorflow::__anon6f8fc96b0111::ExecutorState::FrameState1144 inline bool IsFrameDone() EXCLUSIVE_LOCKS_REQUIRED(mu) {
1145 return (num_pending_inputs == 0 && num_outstanding_iterations == 0);
1146 }
1147
1148 // Returns true if the iteration of the frame is completed.
1149 bool IsIterationDone(int64 iter) EXCLUSIVE_LOCKS_REQUIRED(mu);
1150
1151 // Increments the iteration id. If this is a new iteration, initialize it.
1152 void IncrementIteration(const GraphView* gview, TaggedNodeSeq* ready)
1153 EXCLUSIVE_LOCKS_REQUIRED(mu);
1154
1155 // Activate all the deferred NextIteration nodes in a new iteration.
1156 void ActivateNexts(const GraphView* gview, int64 iter, TaggedNodeSeq* ready)
1157 EXCLUSIVE_LOCKS_REQUIRED(mu);
1158
1159 // Activate all the current loop invariants in a new iteration.
1160 void ActivateLoopInvs(const GraphView* gview, int64 iter,
1161 TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu);
1162
1163 // Add a new loop invariant and make it available to all active
1164 // iterations.
1165 void AddLoopInv(const NodeItem* item, const Entry& value,
1166 TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu);
1167
1168 // Activate the successors of a node. Contents of *outputs are left in an
1169 // indeterminate state after returning from this method.
1170 void ActivateNodes(const NodeItem* item, const bool is_dead, int64 iter,
1171 EntryVector* outputs, TaggedNodeSeq* ready)
1172 EXCLUSIVE_LOCKS_REQUIRED(mu);
1173
1174 // Cleanup iterations of this frame starting from iteration iter.
1175 bool CleanupIterations(const GraphView* gview, int64 iter,
1176 TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu);
1177
~FrameStatetensorflow::__anon6f8fc96b0111::ExecutorState::FrameState1178 ~FrameState() {
1179 for (size_t i = 0; i < iterations.size(); ++i) {
1180 delete iterations[i];
1181 iterations[i] = nullptr;
1182 }
1183 }
1184 };
1185
1186 // A tagged node: <frame*, iter, node*>.
1187 struct TaggedNode {
1188 const Node* node = nullptr;
1189 FrameState* input_frame = nullptr;
1190 int64 input_iter = -1;
1191 bool is_dead = false;
1192
TaggedNodetensorflow::__anon6f8fc96b0111::ExecutorState::TaggedNode1193 TaggedNode(const Node* t_node, FrameState* in_frame, int64 in_iter,
1194 bool dead) {
1195 node = t_node;
1196 input_frame = in_frame;
1197 input_iter = in_iter;
1198 is_dead = dead;
1199 }
1200 };
1201
1202 // A drop-in replacement for std::deque<TaggedNode>. We typically don't
1203 // have that many nodes in the ready queue, so we just use a vector and
1204 // don't free up memory from the queue as we consume nodes.
1205 class TaggedNodeReadyQueue {
1206 public:
TaggedNodeReadyQueue()1207 TaggedNodeReadyQueue() : front_index_(0) {}
1208
push_back(TaggedNode node)1209 void push_back(TaggedNode node) { ready_.push_back(node); }
front() const1210 TaggedNode front() const {
1211 DCHECK_LT(front_index_, ready_.size());
1212 return ready_[front_index_];
1213 }
pop_front()1214 void pop_front() {
1215 DCHECK_LT(front_index_, ready_.size());
1216 front_index_++;
1217 if ((front_index_ == ready_.size()) || (front_index_ > 16384)) {
1218 if (front_index_ == ready_.size()) {
1219 ready_.clear();
1220 } else {
1221 // Lots of unused entries at beginning of vector: move everything
1222 // down to start of vector.
1223 ready_.erase(ready_.begin(), ready_.begin() + front_index_);
1224 }
1225 front_index_ = 0;
1226 }
1227 }
empty() const1228 bool empty() const { return ready_.empty(); }
begin() const1229 const TaggedNode* begin() const { return ready_.begin() + front_index_; }
end() const1230 const TaggedNode* end() const { return ready_.end(); }
1231
1232 private:
1233 gtl::InlinedVector<TaggedNode, 16> ready_;
1234 int front_index_;
1235 };
1236
1237 struct AsyncState;
1238
1239 const bool vlog_; // true if VLOG_IS_ON(1). Used to check vlog cheaply.
1240
1241 // true if LogMemory::IsEnabled(). Used to check memory enabled cheaply.
1242 const bool log_memory_;
1243
1244 int64 step_id_;
1245 // Not owned.
1246 Rendezvous* rendezvous_;
1247 CollectiveExecutor* collective_executor_ = nullptr;
1248 SessionState* session_state_;
1249 string session_handle_;
1250 TensorStore* tensor_store_;
1251 // Step-local container.
1252 ScopedStepContainer* step_container_;
1253 StepStatsCollectorInterface* const stats_collector_;
1254 const tracing::EventCollector* const event_collector_;
1255 Context context_;
1256
1257 // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
1258 // instead of a pointer? (avoids having to delete).
1259 checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
1260 CallFrameInterface* call_frame_;
1261 const ExecutorImpl* impl_;
1262 CancellationManager* cancellation_manager_;
1263 Executor::Args::Runner runner_;
1264 bool sync_on_finish_;
1265 const bool trace_using_annotations_;
1266
1267 // Owned.
1268
1269 // A flag that is set on error after the frame state has been
1270 // dumped for diagnostic purposes.
1271 bool dumped_on_error_ = false;
1272
1273 // The root frame in which the execution of this step is started.
1274 FrameState* root_frame_;
1275
1276 // Invoked when the execution finishes.
1277 Executor::DoneCallback done_cb_;
1278
1279 std::atomic_int_fast32_t num_outstanding_ops_;
1280
1281 // Available via OpKernelContext to every OpKernel invocation.
1282 mutex num_deferred_ops_mu_;
1283 condition_variable no_deferred_ops_cv_;
1284 int64 num_deferred_ops_ GUARDED_BY(num_deferred_ops_mu_) = 0;
1285
1286 mutex mu_;
1287 Status status_ GUARDED_BY(mu_);
1288
1289 // Mapping from frame name to outstanding frames. A new frame is created
1290 // at some iteration of an active frame. So the unique key for the new
1291 // child frame is composed of the name of the parent frame, the iteration
1292 // number at which the parent frame is creating the new frame, and the
1293 // name of the new frame from nodedef.
1294 gtl::FlatMap<string, FrameState*> outstanding_frames_ GUARDED_BY(mu_);
1295
1296 // The unique name of a frame.
MakeFrameName(FrameState * frame,int64 iter_id,const string & name)1297 inline string MakeFrameName(FrameState* frame, int64 iter_id,
1298 const string& name) {
1299 return strings::StrCat(frame->frame_name, ";", iter_id, ";", name);
1300 }
1301
1302 // Find an existing or create a new child frame in the frame 'frame' at
1303 // iteration 'iter'.
1304 void FindOrCreateChildFrame(FrameState* frame, int64 iter, const Node* node,
1305 FrameState** child);
1306
1307 // Delete a frame. Called when the frame is done.
1308 void DeleteFrame(FrameState* frame, TaggedNodeSeq* ready);
1309
1310 // Cleanup frames and iterations starting from frame/iter. Called when
1311 // a child frame is done.
1312 void CleanupFramesIterations(FrameState* frame, int64 iter,
1313 TaggedNodeSeq* ready);
1314
1315 // Process a ready node in current thread.
1316 void Process(TaggedNode node, int64 scheduled_nsec);
1317
1318 // Before invoking item->kernel, fills in its "inputs".
1319 Status PrepareInputs(const NodeItem& item, Entry* first_input,
1320 TensorValueVec* inputs,
1321 DeviceContextVec* input_device_contexts,
1322 AllocatorAttributeVec* input_alloc_attrs,
1323 bool* is_input_dead);
1324
1325 // After item->kernel computation is done, processes its outputs.
1326 Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
1327 EntryVector* outputs, NodeExecStatsInterface* stats);
1328
1329 // After processing the outputs, propagates the outputs to their dsts.
1330 // Contents of *outputs are left in an indeterminate state after
1331 // returning from this method.
1332 void PropagateOutputs(const TaggedNode& tagged_node, const NodeItem* item,
1333 EntryVector* outputs, TaggedNodeSeq* ready);
1334
1335 // "node" just finishes. Takes ownership of "stats". Returns true if
1336 // execution has completed.
1337 bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready,
1338 NodeExecStatsInterface* stats,
1339 TaggedNodeReadyQueue* inline_ready);
1340
1341 // Schedule all the expensive nodes in 'ready', and put all the inexpensive
1342 // nodes in 'ready' into 'inline_ready'.
1343 void ScheduleReady(const TaggedNodeSeq& ready,
1344 TaggedNodeReadyQueue* inline_ready);
1345
1346 // For debugging/logging only.
1347 inline void MaybeMarkCompleted(FrameState* frame, int64 iter, int64 id);
1348
1349 // Provide debugging output about an outstanding node in the executor.
1350 void DumpPendingNodeState(const int node_id, const Entry* input_vector,
1351 bool show_nodes_with_no_ready_inputs);
1352 void DumpActiveNodeState(const int node_id, const Entry* input_vector);
1353
1354 // Provide debugging output about an outstanding iteration in the executor.
1355 void DumpIterationState(const FrameState* frame, IterationState* iteration);
1356
1357 // Provide debugging output of the state of the executor.
1358 void DumpState();
1359 const Tensor* GetTensorValueForDump(const Entry& input);
1360
1361 // Clean up when this executor is done.
1362 void Finish();
1363 // Schedule Finish() on a separate thread if it needs to wait for deferred
1364 // async ops to complete; otherwise run it on the current thread.
1365 void ScheduleFinish();
1366
1367 // A standalone routine for this expression so that we can express
1368 // that we don't want thread safety analysis on this reference (it's
1369 // safe to do without the lock because the iterations array never
1370 // resizes and this particular iteration's array element will not
1371 // be changed out from under us because the iteration is still alive).
GetInputTensors(FrameState * input_frame,int64 input_iter) const1372 Entry* GetInputTensors(FrameState* input_frame,
1373 int64 input_iter) const NO_THREAD_SAFETY_ANALYSIS {
1374 return input_frame->GetIteration(input_iter)->input_tensors;
1375 }
1376 };
1377
ExecutorState(const Executor::Args & args,ExecutorImpl * impl)1378 ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
1379 : vlog_(VLOG_IS_ON(1)),
1380 log_memory_(LogMemory::IsEnabled()),
1381 step_id_(args.step_id),
1382 rendezvous_(args.rendezvous),
1383 collective_executor_(args.collective_executor),
1384 session_state_(args.session_state),
1385 session_handle_(args.session_handle),
1386 tensor_store_(args.tensor_store),
1387 step_container_(args.step_container),
1388 stats_collector_(args.stats_collector),
1389 event_collector_(
1390 tracing::GetEventCollector(tracing::EventCategory::kCompute)),
1391 context_(ContextKind::kThread),
1392 slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
1393 call_frame_(args.call_frame),
1394 impl_(impl),
1395 cancellation_manager_(args.cancellation_manager),
1396 runner_(args.runner),
1397 sync_on_finish_(args.sync_on_finish),
1398 trace_using_annotations_(impl->params_.device->TraceUsingAnnotations()),
1399 num_outstanding_ops_(0) {
1400 // We start the entire execution in iteration 0 of the root frame
1401 // so let us create the root frame and the state for iteration 0.
1402 // We assume root_frame_->frame_name.empty().
1403 root_frame_ = new FrameState(impl_, 1);
1404 root_frame_->frame_id = 0; // must be 0
1405 root_frame_->InitializeFrameInfo(root_frame_->frame_name);
1406
1407 // Initialize iteration 0.
1408 root_frame_->iterations.resize(root_frame_->max_parallel_iterations);
1409 root_frame_->iterations[0] = new IterationState(
1410 root_frame_->pending_counts, root_frame_->total_input_tensors);
1411
1412 outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
1413 }
1414
~ExecutorState()1415 ExecutorState::~ExecutorState() {
1416 for (auto name_frame : outstanding_frames_) {
1417 delete name_frame.second;
1418 }
1419 for (auto it : device_context_map_) {
1420 it->Unref();
1421 }
1422 delete slice_reader_cache_;
1423 }
1424
BuildControlFlowInfo(const Graph * g,ControlFlowInfo * cf_info)1425 Status ExecutorImpl::BuildControlFlowInfo(const Graph* g,
1426 ControlFlowInfo* cf_info) {
1427 const int num_nodes = g->num_node_ids();
1428 cf_info->frame_names.resize(num_nodes);
1429 std::vector<Node*> parent_nodes;
1430 parent_nodes.resize(num_nodes);
1431 std::vector<bool> visited;
1432 visited.resize(num_nodes);
1433
1434 string frame_name;
1435 std::deque<Node*> ready;
1436
1437 // Initialize with the root nodes.
1438 for (Node* n : g->nodes()) {
1439 if (n->in_edges().empty()) {
1440 visited[n->id()] = true;
1441 cf_info->unique_frame_names.insert(frame_name);
1442 ready.push_back(n);
1443 }
1444 }
1445
1446 while (!ready.empty()) {
1447 Node* curr_node = ready.front();
1448 int curr_id = curr_node->id();
1449 ready.pop_front();
1450
1451 Node* parent = nullptr;
1452 if (IsEnter(curr_node)) {
1453 // Enter a child frame.
1454 TF_RETURN_IF_ERROR(
1455 GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name));
1456 parent = curr_node;
1457 } else if (IsExit(curr_node)) {
1458 // Exit to the parent frame.
1459 parent = parent_nodes[curr_id];
1460 frame_name = cf_info->frame_names[parent->id()];
1461 parent = parent_nodes[parent->id()];
1462 } else {
1463 parent = parent_nodes[curr_id];
1464 frame_name = cf_info->frame_names[curr_id];
1465 }
1466
1467 for (const Edge* out_edge : curr_node->out_edges()) {
1468 Node* out = out_edge->dst();
1469 const int out_id = out->id();
1470
1471 // Add to ready queue if not visited.
1472 bool is_visited = visited[out_id];
1473 if (!is_visited) {
1474 ready.push_back(out);
1475 visited[out_id] = true;
1476
1477 // Process the node 'out'.
1478 cf_info->frame_names[out_id] = frame_name;
1479 parent_nodes[out_id] = parent;
1480 cf_info->unique_frame_names.insert(frame_name);
1481 }
1482 }
1483 }
1484
1485 return Status::OK();
1486 }
1487
InitializePending(const Graph * graph,const ControlFlowInfo & cf_info)1488 void ExecutorImpl::InitializePending(const Graph* graph,
1489 const ControlFlowInfo& cf_info) {
1490 for (auto& it : cf_info.unique_frame_names) {
1491 FrameInfo* finfo = EnsureFrameInfo(it);
1492 PendingCounts* counts = new PendingCounts(finfo->pending_counts_layout);
1493 DCHECK_EQ(finfo->pending_counts, nullptr);
1494 finfo->pending_counts = counts;
1495 }
1496 for (const Node* n : graph->nodes()) {
1497 const int id = n->id();
1498 const string& name = cf_info.frame_names[id];
1499 size_t max_pending, max_dead;
1500 GetMaxPendingCounts(n, &max_pending, &max_dead);
1501 const NodeItem* item = gview_.node(id);
1502 PendingCounts* counts = EnsureFrameInfo(name)->pending_counts;
1503 counts->set_initial_count(item->pending_id, max_pending);
1504 }
1505 }
1506
RunAsync(Executor::DoneCallback done)1507 void ExecutorState::RunAsync(Executor::DoneCallback done) {
1508 const Graph* graph = impl_->graph_.get();
1509 TaggedNodeSeq ready;
1510
1511 // Ask the device to fill in the device context map.
1512 Device* device = impl_->params_.device;
1513 const Status fill_status =
1514 device->FillContextMap(graph, &device_context_map_);
1515 if (!fill_status.ok()) {
1516 delete this;
1517 done(fill_status);
1518 return;
1519 }
1520
1521 // Initialize the ready queue.
1522 for (const Node* n : impl_->root_nodes_) {
1523 DCHECK_EQ(n->in_edges().size(), 0);
1524 ready.push_back(TaggedNode{n, root_frame_, 0, false});
1525 }
1526 if (ready.empty()) {
1527 delete this;
1528 done(Status::OK());
1529 } else {
1530 num_outstanding_ops_ = ready.size();
1531 root_frame_->iterations[0]->outstanding_ops = ready.size();
1532 done_cb_ = std::move(done);
1533 // Schedule to run all the ready ops in thread pool.
1534 ScheduleReady(ready, nullptr);
1535 }
1536 }
1537
1538 // State kept alive for executing an asynchronous node in another
1539 // thread. NOTE: We need to make a copy of p.input,
1540 // p.input_device_contexts, and p.input_alloc_attrs for asynchronous
1541 // kernels because OpKernelContext methods like input_type(i) needs
1542 // the param points to valid input type vector. It's not an issue for
1543 // sync kernels because these vectors are kept on the stack.
1544 struct ExecutorState::AsyncState {
AsyncStatetensorflow::__anon6f8fc96b0111::ExecutorState::AsyncState1545 AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
1546 const NodeItem* _item, Entry* _first_input,
1547 NodeExecStatsInterface* _stats)
1548 : saved_inputs(*p.inputs),
1549 saved_input_device_contexts(*p.input_device_contexts),
1550 saved_input_alloc_attrs(*p.input_alloc_attrs),
1551 params(p),
1552 tagged_node(_tagged_node),
1553 item(_item),
1554 first_input(_first_input),
1555 // ParamsButClearingEigenGPUDevice does equivalent of
1556 // params.eigen_gpu_device = nullptr;
1557 ctx(ParamsButClearingEigenGPUDevice(¶ms), item->num_outputs),
1558 stats(_stats) {
1559 params.inputs = &saved_inputs;
1560 params.input_device_contexts = &saved_input_device_contexts;
1561 params.input_alloc_attrs = &saved_input_alloc_attrs;
1562 }
1563
1564 TensorValueVec saved_inputs;
1565 DeviceContextVec saved_input_device_contexts;
1566 AllocatorAttributeVec saved_input_alloc_attrs;
1567 OpKernelContext::Params params;
1568 TaggedNode tagged_node;
1569 const NodeItem* item;
1570 Entry* first_input;
1571 OpKernelContext ctx;
1572 NodeExecStatsInterface* stats;
1573
1574 private:
ParamsButClearingEigenGPUDevicetensorflow::__anon6f8fc96b0111::ExecutorState::AsyncState1575 OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
1576 OpKernelContext::Params* p) {
1577 // Ensure OpKernelContext constructor will make a new eigen GPU device if
1578 // necessary.
1579 p->eigen_gpu_device = nullptr; // Force allocation
1580 return p;
1581 }
1582 };
1583
1584 // Returns true if `item` might be traced by the given trace and event
1585 // collectors. Returns false only if `item` definitely will not be traced.
MightTrace(const NodeItem & item,const tracing::EventCollector * event_collector,bool using_annotations)1586 bool MightTrace(const NodeItem& item,
1587 const tracing::EventCollector* event_collector,
1588 bool using_annotations) {
1589 // Tracing will only be enabled if either `event_collector` is non null,
1590 // or `trace_collector` is non-null and enabled for this particular kernel.
1591 // Although `tracing::ScopedActivity`,
1592 // `tracing::ScopedAnnotation`, and `tracing::ScopedRegion` check subsets of
1593 // these properties internally in their constructors, the cost of passing the
1594 // necessary arguments to them can be significant, so we avoid constructing
1595 // them in the common case (when we know they will not be used).
1596 if (event_collector != nullptr) {
1597 return true;
1598 }
1599 auto* trace_collector = tracing::GetTraceCollector();
1600 if (trace_collector) {
1601 if (using_annotations) {
1602 return trace_collector->IsEnabledForAnnotations();
1603 } else {
1604 return trace_collector->IsEnabledForActivities(
1605 item.kernel->IsExpensive());
1606 }
1607 }
1608 return false;
1609 }
1610
Process(TaggedNode tagged_node,int64 scheduled_nsec)1611 void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
1612 WithContext wc(context_);
1613 const GraphView& gview = impl_->gview_;
1614 TaggedNodeSeq ready;
1615 TaggedNodeReadyQueue inline_ready;
1616
1617 // Parameters passed to OpKernel::Compute.
1618 TensorValueVec inputs;
1619 DeviceContextVec input_device_contexts;
1620 AllocatorAttributeVec input_alloc_attrs;
1621
1622 OpKernelContext::Params params;
1623 params.step_id = step_id_;
1624 Device* device = impl_->params_.device;
1625 params.device = device;
1626 params.log_memory = log_memory_;
1627 params.record_tensor_accesses = impl_->device_record_tensor_accesses_;
1628 params.rendezvous = rendezvous_;
1629 params.collective_executor = collective_executor_;
1630 params.session_state = session_state_;
1631 params.session_handle = session_handle_;
1632 params.tensor_store = tensor_store_;
1633 params.cancellation_manager = cancellation_manager_;
1634 params.call_frame = call_frame_;
1635 params.function_library = impl_->params_.function_library;
1636 params.resource_manager = device->resource_manager();
1637 params.step_container = step_container_;
1638 params.slice_reader_cache = slice_reader_cache_;
1639 params.inputs = &inputs;
1640 params.input_device_contexts = &input_device_contexts;
1641 params.input_alloc_attrs = &input_alloc_attrs;
1642 params.runner = &runner_;
1643 params.stats_collector = stats_collector_;
1644 params.inc_num_deferred_ops_function = [this]() {
1645 mutex_lock lock(num_deferred_ops_mu_);
1646 num_deferred_ops_++;
1647 };
1648 params.dec_num_deferred_ops_function = [this]() {
1649 mutex_lock lock(num_deferred_ops_mu_);
1650 num_deferred_ops_--;
1651 if (num_deferred_ops_ == 0) {
1652 no_deferred_ops_cv_.notify_all();
1653 }
1654 };
1655
1656 Status s;
1657 NodeExecStatsInterface* stats = nullptr;
1658
1659 EntryVector outputs;
1660 bool completed = false;
1661 inline_ready.push_back(tagged_node);
1662 while (!inline_ready.empty()) {
1663 tagged_node = inline_ready.front();
1664 inline_ready.pop_front();
1665 const Node* node = tagged_node.node;
1666 FrameState* input_frame = tagged_node.input_frame;
1667 const int64 input_iter = tagged_node.input_iter;
1668 const int id = node->id();
1669 const NodeItem& item = *gview.node(id);
1670
1671 // TODO(misard) Replace with a finer-grain enabling flag once we
1672 // add better optional debugging support.
1673 if (vlog_ && VLOG_IS_ON(1)) {
1674 mutex_lock l(input_frame->mu);
1675 input_frame->GetIteration(input_iter)->mark_started(item.pending_id);
1676 }
1677
1678 // Set the device_context for this node id, if it exists.
1679 if (id < device_context_map_.size()) {
1680 params.op_device_context = device_context_map_[id];
1681 }
1682
1683 params.track_allocations = false;
1684 stats = nullptr;
1685 if (stats_collector_ && !tagged_node.is_dead) {
1686 stats = stats_collector_->CreateNodeExecStats(node);
1687 // Track allocations if and only if we are collecting statistics, and
1688 // `stats` object is expecting allocations to be tracked.
1689 params.track_allocations = stats ? stats->TrackAllocations() : false;
1690 nodestats::SetScheduled(stats, scheduled_nsec);
1691 nodestats::SetAllStart(stats);
1692 }
1693
1694 if (vlog_) {
1695 VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
1696 << SummarizeNode(*node) << (tagged_node.is_dead ? " is dead" : "")
1697 << " device: " << device->name();
1698 }
1699
1700 Entry* input_tensors = GetInputTensors(input_frame, input_iter);
1701 Entry* first_input = input_tensors + item.input_start;
1702 outputs.clear();
1703
1704 TensorReferenceVector accessed_tensors;
1705 DeviceContext* device_context = nullptr;
1706 // Only execute this node if it is not dead or it is a send/recv
1707 // transfer node. For transfer nodes, we need to propagate the "dead"
1708 // bit even when the node is dead.
1709 bool launched_asynchronously = false;
1710 if (tagged_node.is_dead && !IsTransferNode(node)) {
1711 outputs.resize(item.num_outputs);
1712 } else {
1713 // Prepares inputs.
1714 bool is_input_dead = false;
1715 s = PrepareInputs(item, first_input, &inputs, &input_device_contexts,
1716 &input_alloc_attrs, &is_input_dead);
1717 if (!s.ok()) {
1718 // Clear inputs.
1719 int num_inputs = item.num_inputs;
1720 for (int i = 0; i < num_inputs; ++i) {
1721 (first_input + i)->ClearVal();
1722 }
1723 MaybeMarkCompleted(input_frame, input_iter, id);
1724 // Continue to process the nodes in 'inline_ready'.
1725 completed = NodeDone(s, item.node, ready, stats, &inline_ready);
1726 continue;
1727 }
1728
1729 // Set up compute params.
1730 OpKernel* op_kernel = item.kernel;
1731 params.op_kernel = op_kernel;
1732 params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
1733 params.is_input_dead = is_input_dead;
1734 params.output_attr_array = item.output_attrs();
1735 params.forward_from_array = item.forward_from();
1736
1737 if (item.kernel_is_async) {
1738 // Asynchronous computes.
1739 AsyncOpKernel* async = item.kernel->AsAsync();
1740 DCHECK(async != nullptr);
1741 launched_asynchronously = true;
1742 AsyncState* state =
1743 new AsyncState(params, tagged_node, &item, first_input, stats);
1744
1745 auto done = [this, state]() {
1746 Device* device = impl_->params_.device;
1747 NodeExecStatsInterface* stats = state->stats; // Shorthand
1748 Entry* first_input = state->first_input; // Shorthand
1749
1750 nodestats::SetOpEnd(stats);
1751 EntryVector outputs;
1752 Status s = ProcessOutputs(*state->item, &state->ctx, &outputs, stats);
1753 nodestats::SetMemory(stats, &state->ctx);
1754 if (vlog_) {
1755 VLOG(2) << "Async kernel done: " << state->item->node->id()
1756 << " step " << step_id_ << " "
1757 << SummarizeNode(*state->item->node)
1758 << (state->tagged_node.is_dead ? " is dead" : "")
1759 << " device: " << device->name();
1760 }
1761
1762 // Clears inputs.
1763 const int num_inputs = state->item->num_inputs;
1764 for (int i = 0; i < num_inputs; ++i) {
1765 (first_input + i)->ClearVal();
1766 }
1767 FrameState* input_frame = state->tagged_node.input_frame;
1768 const int64 input_iter = state->tagged_node.input_iter;
1769 const int id = state->tagged_node.node->id();
1770 MaybeMarkCompleted(input_frame, input_iter, id);
1771 TaggedNodeSeq ready;
1772 if (s.ok()) {
1773 PropagateOutputs(state->tagged_node, state->item, &outputs, &ready);
1774 }
1775 outputs.clear();
1776 if (s.ok() && impl_->device_record_tensor_accesses_) {
1777 // Get the list of all tensors accessed during the execution
1778 TensorReferenceVector accessed;
1779 state->ctx.retrieve_accessed_tensors(&accessed);
1780 nodestats::SetReferencedTensors(stats, accessed);
1781 // callee takes ownership of the vector
1782 device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
1783 accessed);
1784 }
1785 const bool completed =
1786 NodeDone(s, state->item->node, ready, stats, nullptr);
1787 delete state;
1788 if (completed) ScheduleFinish();
1789 };
1790 nodestats::SetOpStart(stats);
1791 device->ComputeAsync(async, &state->ctx, done);
1792 } else {
1793 // Synchronous computes.
1794 OpKernelContext ctx(¶ms, item.num_outputs);
1795 nodestats::SetOpStart(stats);
1796
1797 if (TF_PREDICT_FALSE(
1798 MightTrace(item, event_collector_, trace_using_annotations_))) {
1799 const string& op_name = op_kernel->name();
1800 tracing::ScopedRegion region(tracing::EventCategory::kCompute,
1801 op_name);
1802 if (trace_using_annotations_) {
1803 // The OpKernel may create child activities (such as GPU kernel
1804 // launches), so use a `ScopedAnnotation` to relate these activities
1805 // in the trace.
1806 tracing::ScopedAnnotation activity(
1807 op_name, strings::StrCat(op_kernel->type_string(),
1808 "#id=", step_id_, "#"));
1809 device->Compute(op_kernel, &ctx);
1810 } else {
1811 // Use the cheaper `ScopedActivity` to trace just the OpKernel
1812 // execution.
1813 tracing::ScopedActivity activity(
1814 op_name,
1815 strings::StrCat(op_kernel->type_string(), "#id=", step_id_,
1816 "#"),
1817 item.kernel->IsExpensive());
1818 device->Compute(op_kernel, &ctx);
1819 }
1820 } else {
1821 // In the common case, avoid creating any tracing objects.
1822 if (op_kernel->IsExpensive()) {
1823 KernelTimer timer;
1824 device->Compute(op_kernel, &ctx);
1825 op_kernel->UpdateCostEstimate(timer.ElapsedCycles());
1826 } else {
1827 device->Compute(op_kernel, &ctx);
1828 }
1829 }
1830
1831 nodestats::SetOpEnd(stats);
1832 s = ProcessOutputs(item, &ctx, &outputs, stats);
1833 if (s.ok() && impl_->device_record_tensor_accesses_) {
1834 // Get the list of all tensors accessed during the execution
1835 ctx.retrieve_accessed_tensors(&accessed_tensors);
1836 device_context = ctx.op_device_context();
1837 }
1838 nodestats::SetMemory(stats, &ctx);
1839 }
1840 }
1841
1842 if (!launched_asynchronously) {
1843 if (vlog_) {
1844 VLOG(2) << "Synchronous kernel done: " << id << " step "
1845 << params.step_id << " " << SummarizeNode(*node)
1846 << (tagged_node.is_dead ? " is dead: " : "")
1847 << " device: " << device->name();
1848 }
1849
1850 // Clears inputs.
1851 const int num_inputs = item.num_inputs;
1852 for (int i = 0; i < num_inputs; ++i) {
1853 (first_input + i)->ClearVal();
1854 }
1855 MaybeMarkCompleted(input_frame, input_iter, id);
1856 // Propagates outputs.
1857 if (s.ok()) {
1858 PropagateOutputs(tagged_node, &item, &outputs, &ready);
1859 }
1860 outputs.clear();
1861 if (!accessed_tensors.empty()) {
1862 nodestats::SetReferencedTensors(stats, accessed_tensors);
1863 // device_context is set above in synchronous computes
1864 device->ConsumeListOfAccessedTensors(device_context, accessed_tensors);
1865 }
1866 if (stats) {
1867 scheduled_nsec = nodestats::NowInNsec();
1868 }
1869 // Postprocess.
1870 completed = NodeDone(s, item.node, ready, stats, &inline_ready);
1871 }
1872 } // while !inline_ready.empty()
1873
1874 // This thread of computation is done if completed = true.
1875 if (completed) ScheduleFinish();
1876 }
1877
PrepareInputs(const NodeItem & item,Entry * first_input,TensorValueVec * inputs,DeviceContextVec * input_device_contexts,AllocatorAttributeVec * input_alloc_attrs,bool * is_input_dead)1878 Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input,
1879 TensorValueVec* inputs,
1880 DeviceContextVec* input_device_contexts,
1881 AllocatorAttributeVec* input_alloc_attrs,
1882 bool* is_input_dead) {
1883 const Node* node = item.node;
1884
1885 inputs->clear();
1886 inputs->resize(item.num_inputs);
1887 input_device_contexts->clear();
1888 input_device_contexts->resize(item.num_inputs);
1889 input_alloc_attrs->clear();
1890 input_alloc_attrs->resize(item.num_inputs);
1891
1892 *is_input_dead = false;
1893
1894 bool is_merge = item.is_merge;
1895 for (int i = 0; i < item.num_inputs; ++i) {
1896 const bool expect_ref = IsRefType(item.input_type(i));
1897 Entry* entry = first_input + i;
1898 (*input_device_contexts)[i] = entry->device_context;
1899 (*input_alloc_attrs)[i] = entry->alloc_attr;
1900
1901 // i-th input.
1902 TensorValue* inp = &(*inputs)[i];
1903
1904 // Only merge and transfer nodes can have no-value inputs.
1905 if (!entry->has_value) {
1906 if (!is_merge) {
1907 DCHECK(IsTransferNode(node)) << node->name() << " - input " << i;
1908 DCHECK(!entry->val_field_is_set) << node->name() << " - input " << i;
1909 entry->has_value = true;
1910 entry->val_field_is_set = true;
1911 entry->val.Init(*kEmptyTensor);
1912 inp->tensor = entry->val.get();
1913 *is_input_dead = true;
1914 }
1915 continue;
1916 }
1917 if (entry->ref == nullptr) {
1918 if (expect_ref) {
1919 return AttachDef(
1920 errors::InvalidArgument(i, "-th input expects a ref type"),
1921 item.kernel->def());
1922 }
1923 inp->tensor = entry->val.get();
1924 } else {
1925 {
1926 tf_shared_lock ml(*entry->ref_mu);
1927 if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) {
1928 return AttachDef(errors::FailedPrecondition(
1929 "Attempting to use uninitialized value ",
1930 item.kernel->requested_input(i)),
1931 item.kernel->def());
1932 }
1933 }
1934 if (expect_ref) {
1935 inp->mutex_if_ref = entry->ref_mu;
1936 inp->tensor = entry->ref;
1937 } else {
1938 // Automatically deref the tensor ref when the op expects a
1939 // tensor but is given a ref to a tensor. Need to deref it
1940 // under the mutex.
1941 {
1942 tf_shared_lock l(*(entry->ref_mu));
1943 DCHECK(!entry->val_field_is_set);
1944 entry->val.Init(*entry->ref);
1945 entry->val_field_is_set = true;
1946 }
1947 entry->ref = nullptr;
1948 entry->ref_mu = nullptr;
1949
1950 inp->tensor = entry->val.get();
1951 // The dtype of entry->ref could have been changed by another operation
1952 // that ran after the operation that "produced" it executed, so
1953 // re-validate that the type of the dereferenced tensor matches the
1954 // expected input type.
1955 if (item.input_type(i) != inp->tensor->dtype()) {
1956 return AttachDef(
1957 errors::InvalidArgument(
1958 i, "-th input expects type ",
1959 DataTypeString(item.input_type(i)),
1960 " but automatically dereferenced input tensor has type ",
1961 DataTypeString(inp->tensor->dtype())),
1962 item.kernel->def());
1963 }
1964 }
1965 }
1966 }
1967 return Status::OK();
1968 }
1969
ProcessOutputs(const NodeItem & item,OpKernelContext * ctx,EntryVector * outputs,NodeExecStatsInterface * stats)1970 Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
1971 EntryVector* outputs,
1972 NodeExecStatsInterface* stats) {
1973 const Node* node = item.node;
1974 DCHECK_EQ(0, outputs->size());
1975 outputs->resize(item.num_outputs);
1976
1977 Status s = ctx->status();
1978 if (!s.ok()) {
1979 s = AttachDef(s, item.kernel->def());
1980 // TODO(misard) Replace with a finer-grain enabling flag once we
1981 // add better optional debugging support.
1982 if (vlog_ && VLOG_IS_ON(1)) {
1983 LOG(WARNING) << this << " Compute status: " << s;
1984 DumpState();
1985 }
1986 if (s.code() == error::RESOURCE_EXHAUSTED) {
1987 if (stats_collector_) {
1988 string err = stats_collector_->ReportAllocsOnResourceExhausted(
1989 s.error_message());
1990 s = Status(s.code(), strings::StrCat(s.error_message(), err));
1991 } else {
1992 s = Status(
1993 s.code(),
1994 strings::StrCat(
1995 s.error_message(),
1996 "\nHint: If you want to see a list of allocated tensors when "
1997 "OOM happens, add report_tensor_allocations_upon_oom "
1998 "to RunOptions for current allocation info.\n"));
1999 }
2000 }
2001 return s;
2002 }
2003
2004 // Get the device_context for this node id, if it exists.
2005 DeviceContext* device_context = nullptr;
2006 if (node->id() < device_context_map_.size()) {
2007 device_context = device_context_map_[node->id()];
2008 }
2009
2010 for (int i = 0; i < item.num_outputs; ++i) {
2011 const TensorValue val = ctx->release_output(i);
2012 if (val.tensor == nullptr) {
2013 // Unless it's a Switch or a Recv, the node must produce a
2014 // tensor value at i-th output.
2015 if (!IsSwitch(node) && !IsRecv(node)) {
2016 s.Update(errors::Internal("Missing ", i, "-th output from ",
2017 FormatNodeForError(*node)));
2018 }
2019 } else {
2020 Entry* out = &((*outputs)[i]);
2021
2022 // Set the device context of the output entry.
2023 out->device_context = device_context;
2024
2025 // Set the allocator attributes of the output entry.
2026 out->alloc_attr = ctx->output_alloc_attr(i);
2027
2028 // Sanity check of output tensor types.
2029 DataType dtype;
2030 if (val.is_ref()) {
2031 tf_shared_lock ml(*val.mutex_if_ref);
2032 dtype = MakeRefType(val->dtype());
2033 } else {
2034 dtype = val->dtype();
2035 }
2036 if (dtype == item.output_type(i)) {
2037 if (stats && val.tensor->IsInitialized()) {
2038 nodestats::SetOutput(stats, i, val.tensor);
2039 }
2040 if (val.is_ref()) {
2041 out->has_value = true;
2042 out->ref = val.tensor;
2043 out->ref_mu = val.mutex_if_ref;
2044 if (log_memory_) {
2045 Tensor to_log;
2046 {
2047 // Dereference the tensor under the lock.
2048 tf_shared_lock l(*out->ref_mu);
2049 to_log = *out->ref;
2050 }
2051 LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
2052 ctx->step_id(), i, to_log);
2053 }
2054 } else {
2055 // NOTE that std::move is used here, so val.tensor goes to
2056 // uninitialized state (val.tensor->IsInitialized return false).
2057 DCHECK(!out->val_field_is_set);
2058 out->has_value = true;
2059 out->val_field_is_set = true;
2060 out->val.Init(std::move(*val.tensor));
2061 if (log_memory_) {
2062 LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
2063 ctx->step_id(), i, *out->val);
2064 }
2065 }
2066 } else {
2067 s.Update(errors::Internal("Output ", i, " of type ",
2068 DataTypeString(dtype),
2069 " does not match declared output type ",
2070 DataTypeString(item.output_type(i)),
2071 " for node ", FormatNodeForError(*node)));
2072 }
2073 }
2074 if (!val.is_ref()) {
2075 // If OpKernelContext returns outputs via pass-by-value, we
2076 // don't need this trouble.
2077 delete val.tensor;
2078 }
2079 }
2080 return s;
2081 }
2082
PropagateOutputs(const TaggedNode & tagged_node,const NodeItem * item,EntryVector * outputs,TaggedNodeSeq * ready)2083 void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
2084 const NodeItem* item, EntryVector* outputs,
2085 TaggedNodeSeq* ready) {
2086 auto activity_handle =
2087 [&]() -> std::unique_ptr<tracing::TraceCollector::Handle> {
2088 auto* trace_collector = tracing::GetTraceCollector();
2089 if (TF_PREDICT_FALSE(trace_collector != nullptr &&
2090 trace_collector->IsEnabledForActivities(
2091 false /* is_expensive */))) {
2092 const string& op_name = item->kernel->name();
2093 // Intentionally using ExecutorPropagateOutputs as the first key so that
2094 // users are aware that it's not the op invocation.
2095 return trace_collector->CreateActivityHandle(
2096 "ExecutorPropagateOutputs",
2097 strings::StrCat(op_name, "#id=", step_id_, "#"),
2098 false /* is_expensive */);
2099 } else {
2100 return nullptr;
2101 }
2102 }();
2103
2104 const Node* node = tagged_node.node;
2105 FrameState* input_frame = tagged_node.input_frame;
2106 const int64 input_iter = tagged_node.input_iter;
2107 const bool is_dead = tagged_node.is_dead;
2108
2109 // Propagates outputs along out edges, and puts newly ready nodes
2110 // into the ready queue.
2111 ready->clear();
2112 bool is_frame_done = false;
2113 FrameState* output_frame = input_frame;
2114 int64 output_iter = input_iter;
2115
2116 if (!item->is_enter_exit_or_next_iter) {
2117 // Fast path for nodes types that don't need special handling
2118 DCHECK_EQ(input_frame, output_frame);
2119 // Normal path for most nodes
2120 mutex_lock l(input_frame->mu);
2121 output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
2122 is_frame_done = input_frame->DecrementOutstandingOpsLocked(
2123 &impl_->gview_, input_iter, ready);
2124 } else if (item->is_enter) {
2125 FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame);
2126 output_iter = 0;
2127 {
2128 const NodeItem* item = impl_->gview_.node(node->id());
2129 mutex_lock l(output_frame->mu);
2130 if (item->is_constant_enter) {
2131 // Propagate to all active iterations if this is a loop invariant.
2132 output_frame->AddLoopInv(item, (*outputs)[0], ready);
2133 } else {
2134 output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
2135 }
2136 output_frame->num_pending_inputs--;
2137 }
2138 is_frame_done =
2139 input_frame->DecrementOutstandingOps(&impl_->gview_, input_iter, ready);
2140 } else if (item->is_exit) {
2141 if (is_dead) {
2142 mutex_lock l(input_frame->mu);
2143 // Stop and remember this node if it is a dead exit.
2144 if (input_iter == input_frame->iteration_count) {
2145 input_frame->dead_exits.push_back(node);
2146 }
2147 is_frame_done = input_frame->DecrementOutstandingOpsLocked(
2148 &impl_->gview_, input_iter, ready);
2149 } else {
2150 output_frame = input_frame->parent_frame;
2151 output_iter = input_frame->parent_iter;
2152 {
2153 mutex_lock l(output_frame->mu);
2154 output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
2155 }
2156 is_frame_done = input_frame->DecrementOutstandingOps(&impl_->gview_,
2157 input_iter, ready);
2158 }
2159 } else {
2160 DCHECK(IsNextIteration(node));
2161 mutex_lock l(input_frame->mu);
2162 if (is_dead) {
2163 // Stop the deadness propagation.
2164 output_frame = nullptr;
2165 } else {
2166 if (input_iter == input_frame->iteration_count &&
2167 input_frame->num_outstanding_iterations ==
2168 input_frame->max_parallel_iterations) {
2169 // Reached the maximum for parallel iterations.
2170 input_frame->next_iter_roots.push_back({node, (*outputs)[0]});
2171 output_frame = nullptr;
2172 } else {
2173 // If this is a new iteration, start it.
2174 if (input_iter == input_frame->iteration_count) {
2175 input_frame->IncrementIteration(&impl_->gview_, ready);
2176 }
2177 output_iter = input_iter + 1;
2178 }
2179 }
2180 if (output_frame != nullptr) {
2181 // This is the case when node is not Enter, Exit, or NextIteration.
2182 DCHECK(input_frame == output_frame);
2183 output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
2184 }
2185 is_frame_done = input_frame->DecrementOutstandingOpsLocked(
2186 &impl_->gview_, input_iter, ready);
2187 }
2188
2189 // At this point, this node is completely done. We also know if the
2190 // completion of this node makes its frame completed.
2191 if (is_frame_done) {
2192 FrameState* parent_frame = input_frame->parent_frame;
2193 const int64 parent_iter = input_frame->parent_iter;
2194 DeleteFrame(input_frame, ready);
2195 if (parent_frame != nullptr) {
2196 // The completion of frame may cause completions in its parent frame.
2197 // So clean things up recursively.
2198 CleanupFramesIterations(parent_frame, parent_iter, ready);
2199 }
2200 }
2201 }
2202
NodeDone(const Status & s,const Node * node,const TaggedNodeSeq & ready,NodeExecStatsInterface * stats,TaggedNodeReadyQueue * inline_ready)2203 bool ExecutorState::NodeDone(const Status& s, const Node* node,
2204 const TaggedNodeSeq& ready,
2205 NodeExecStatsInterface* stats,
2206 TaggedNodeReadyQueue* inline_ready) {
2207 nodestats::SetAllEnd(stats);
2208 if (stats) {
2209 if (stats_collector_) {
2210 stats->Done(impl_->params_.device->name());
2211 } else {
2212 delete stats;
2213 }
2214 }
2215
2216 bool abort_run = false;
2217 if (!s.ok()) {
2218 // Some error happened. This thread of computation is done.
2219 mutex_lock l(mu_);
2220 if (status_.ok()) {
2221 abort_run = true;
2222 status_ = s;
2223 }
2224 }
2225 if (abort_run) {
2226 TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
2227 if (rendezvous_) {
2228 rendezvous_->StartAbort(s);
2229 }
2230 if (collective_executor_) {
2231 collective_executor_->StartAbort(s);
2232 }
2233 if (cancellation_manager_) {
2234 cancellation_manager_->StartCancel();
2235 }
2236 }
2237
2238 bool completed = false;
2239 const size_t ready_size = ready.size();
2240 if (ready_size == 0 || !s.ok()) {
2241 completed = (num_outstanding_ops_.fetch_sub(1) == 1);
2242 } else if (ready_size > 1) {
2243 num_outstanding_ops_.fetch_add(ready_size - 1, std::memory_order_relaxed);
2244 }
2245
2246 // Schedule the ready nodes in 'ready'.
2247 if (s.ok()) {
2248 ScheduleReady(ready, inline_ready);
2249 }
2250 return completed;
2251 }
2252
ScheduleReady(const TaggedNodeSeq & ready,TaggedNodeReadyQueue * inline_ready)2253 void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
2254 TaggedNodeReadyQueue* inline_ready) {
2255 if (ready.empty()) return;
2256
2257 int64 scheduled_nsec = 0;
2258 if (stats_collector_) {
2259 scheduled_nsec = nodestats::NowInNsec();
2260 }
2261
2262 if (inline_ready == nullptr) {
2263 // Schedule to run all the ready ops in thread pool.
2264 for (auto& tagged_node : ready) {
2265 runner_([=]() { Process(tagged_node, scheduled_nsec); });
2266 }
2267 return;
2268 }
2269
2270 const GraphView& gview = impl_->gview_;
2271 const TaggedNode* curr_expensive_node = nullptr;
2272 for (auto& tagged_node : ready) {
2273 const NodeItem& item = *gview.node(tagged_node.node->id());
2274 if (tagged_node.is_dead || !item.kernel->IsExpensive()) {
2275 // Inline this inexpensive node.
2276 inline_ready->push_back(tagged_node);
2277 } else {
2278 if (curr_expensive_node) {
2279 // Dispatch to another thread since there is plenty of work to
2280 // do for this thread.
2281 runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
2282 scheduled_nsec));
2283 }
2284 curr_expensive_node = &tagged_node;
2285 }
2286 }
2287 if (curr_expensive_node) {
2288 if (inline_ready->empty()) {
2289 // Tail recursion optimization
2290 inline_ready->push_back(*curr_expensive_node);
2291 } else {
2292 // There are inline nodes to run already. We dispatch this expensive
2293 // node to other thread.
2294 runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
2295 scheduled_nsec));
2296 }
2297 }
2298 }
2299
MaybeMarkCompleted(FrameState * frame,int64 iter,int64 node_id)2300 inline void ExecutorState::MaybeMarkCompleted(FrameState* frame, int64 iter,
2301 int64 node_id) {
2302 // TODO(misard) Replace with a finer-grain enabling flag once we
2303 // add better optional debugging support.
2304 if (vlog_ && VLOG_IS_ON(1)) {
2305 const NodeItem* item = impl_->gview_.node(node_id);
2306 mutex_lock l(frame->mu);
2307 frame->GetIteration(iter)->mark_completed(item->pending_id);
2308 }
2309 }
2310
GetTensorValueForDump(const Entry & input)2311 const Tensor* ExecutorState::GetTensorValueForDump(const Entry& input) {
2312 if (!input.has_value) {
2313 return kEmptyTensor;
2314 } else if (input.ref == nullptr) {
2315 return input.val.get();
2316 } else {
2317 return input.ref;
2318 }
2319 }
2320
DumpPendingNodeState(const int node_id,const Entry * input_vector,const bool show_nodes_with_no_ready_inputs)2321 void ExecutorState::DumpPendingNodeState(
2322 const int node_id, const Entry* input_vector,
2323 const bool show_nodes_with_no_ready_inputs) {
2324 const NodeItem& node_item = *impl_->gview_.node(node_id);
2325 const Node& node = *node_item.node;
2326 const int input_base = node_item.input_start;
2327 if (!show_nodes_with_no_ready_inputs) {
2328 bool has_ready_input = false;
2329 for (int i = 0; i < node.num_inputs(); ++i) {
2330 const Entry& input = input_vector[input_base + i];
2331 const Tensor* tensor = GetTensorValueForDump(input);
2332 if (tensor->IsInitialized()) {
2333 has_ready_input = true;
2334 break;
2335 }
2336 }
2337 if (!has_ready_input) {
2338 return;
2339 }
2340 }
2341 LOG(WARNING) << " Pending Node: " << node.DebugString();
2342 for (int i = 0; i < node.num_inputs(); ++i) {
2343 const Entry& input = input_vector[input_base + i];
2344 const Tensor* tensor = GetTensorValueForDump(input);
2345 if (tensor->IsInitialized()) {
2346 LOG(WARNING) << " Input " << i << ": "
2347 << strings::StrCat(
2348 "Tensor<type: ", DataTypeString(tensor->dtype()),
2349 " shape: ", tensor->shape().DebugString(), ">");
2350 } else {
2351 LOG(WARNING) << " Input " << i << ": not present";
2352 }
2353 }
2354 }
2355
DumpActiveNodeState(const int node_id,const Entry * input_vector)2356 void ExecutorState::DumpActiveNodeState(const int node_id,
2357 const Entry* input_vector) {
2358 const NodeItem& node_item = *impl_->gview_.node(node_id);
2359 const Node& node = *node_item.node;
2360 LOG(WARNING) << " Active Node: " << node.DebugString();
2361 const int input_base = node_item.input_start;
2362 for (int i = 0; i < node.num_inputs(); ++i) {
2363 const Entry& input = input_vector[input_base + i];
2364 const Tensor* tensor = GetTensorValueForDump(input);
2365 if (tensor->IsInitialized()) {
2366 LOG(WARNING) << " Input " << i << ": "
2367 << strings::StrCat(
2368 "Tensor<type: ", DataTypeString(tensor->dtype()),
2369 " shape: ", tensor->shape().DebugString(), ">");
2370 } else {
2371 LOG(WARNING) << " Input " << i << ": not present";
2372 }
2373 }
2374 }
2375
DumpIterationState(const FrameState * frame,IterationState * iteration)2376 void ExecutorState::DumpIterationState(const FrameState* frame,
2377 IterationState* iteration) {
2378 const std::vector<const Node*>* nodes = frame->nodes;
2379 // Dump any waiting nodes that are holding on to tensors.
2380 for (const Node* node : *nodes) {
2381 const int node_id = node->id();
2382 PendingCounts::Handle pending_id = impl_->gview_.node(node_id)->pending_id;
2383 if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
2384 iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
2385 DumpPendingNodeState(node_id, iteration->input_tensors, false);
2386 }
2387 }
2388 // Then the active nodes.
2389 for (const Node* node : *nodes) {
2390 const int node_id = node->id();
2391 PendingCounts::Handle pending_id = impl_->gview_.node(node_id)->pending_id;
2392 if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
2393 DumpActiveNodeState(node_id, iteration->input_tensors);
2394 }
2395 }
2396 // Show all input tensors in use.
2397 const int total_input_tensors = frame->total_input_tensors;
2398 size_t total_bytes = 0;
2399 for (int i = 0; i < total_input_tensors; ++i) {
2400 const Entry& input = iteration->input_tensors[i];
2401 const Tensor* tensor = GetTensorValueForDump(input);
2402 if (tensor->IsInitialized()) {
2403 LOG(WARNING) << " Input " << i << ": "
2404 << strings::StrCat(
2405 "Tensor<type: ", DataTypeString(tensor->dtype()),
2406 " shape: ", tensor->shape().DebugString(),
2407 ", bytes: ", tensor->TotalBytes(), ">");
2408 total_bytes += tensor->TotalBytes();
2409 }
2410 }
2411 LOG(WARNING) << " Total bytes " << total_bytes;
2412 }
2413
DumpState()2414 void ExecutorState::DumpState() {
2415 mutex_lock l(mu_);
2416 if (!dumped_on_error_) {
2417 LOG(WARNING) << "Dumping state";
2418 for (auto& frame : outstanding_frames_) {
2419 LOG(WARNING) << frame.first;
2420 FrameState* frame_state = frame.second;
2421 mutex_lock frame_lock(frame_state->mu);
2422 for (IterationState* iteration : frame_state->iterations) {
2423 LOG(WARNING) << " Iteration:";
2424 DumpIterationState(frame_state, iteration);
2425 }
2426 }
2427 dumped_on_error_ = true;
2428 }
2429 }
2430
ScheduleFinish()2431 void ExecutorState::ScheduleFinish() {
2432 int num_deferred_ops;
2433 {
2434 mutex_lock lock(num_deferred_ops_mu_);
2435 num_deferred_ops = num_deferred_ops_;
2436 }
2437 if (num_deferred_ops > 0) {
2438 // Finish() may be blocked waiting for deferred async ops to complete. The
2439 // execution of deferred async ops may be waiting for non-enqueued ops of
2440 // other executors to complete. So running Finish() on the current thread
2441 // (inter-op threadpool thread) may lead to a deadlock due to threadpool
2442 // exhaustion. Instead, we run it on a separate thread to unblock the
2443 // threadpool thread.
2444 Env::Default()->SchedClosure([this]() { Finish(); });
2445 } else {
2446 Finish();
2447 }
2448 }
2449
Finish()2450 void ExecutorState::Finish() {
2451 mu_.lock();
2452 auto status = status_;
2453 auto done_cb = std::move(done_cb_);
2454 auto runner = std::move(runner_);
2455 mu_.unlock();
2456 CHECK(done_cb != nullptr);
2457 Device* device = impl_->params_.device;
2458
2459 // There are several potential race conditions below. To name a few:
2460 // 1. Even if the device's status is OK at the precise moment when
2461 // num_deferred_ops_ reaches 0, it could go bad before device->RefreshStatus()
2462 // is called below, caused by work enqueued onto the same device by other
2463 // concurrent ExecutorState objects.
2464 // 2. Some implementations of Device::RefreshStatus, such as
2465 // XlaDevice::RefreshStatus, may be inherently racy because it releases the
2466 // device mutex after a stream pointer is acquired and before the stream is
2467 // queried for status.
2468 // 3. It's the same for some implementations of Device::Sync, such as
2469 // XlaDevice::Sync.
2470 //
2471 // However, these race conditions are acceptable because a stream (and
2472 // therefore an XlaDevice) can only go from OK to not-OK, never the opposite,
2473 // which means we will at worst report errors when there isn't any, never the
2474 // opposite.
2475
2476 // If inc_num_deferred_ops_function has ever been called, ExecutorState must
2477 // wait for all corresponding dec_num_deferred_ops_function calls to happen
2478 // regardless of status. This ensures that dec_num_deferred_ops_function can
2479 // safely use ExecutorState's resources.
2480 {
2481 mutex_lock lock(num_deferred_ops_mu_);
2482 while (num_deferred_ops_ > 0) {
2483 no_deferred_ops_cv_.wait(lock);
2484 }
2485 }
2486
2487 // An early exit for devices don't allow sync on completion. Ops that run on
2488 // these devices should have used num_deferred_ops correctly to ensure the
2489 // device has finished all relevant work at this point.
2490 if (!device->AllowsSyncOnCompletion()) {
2491 status.Update(device->RefreshStatus());
2492 if (!status.ok()) {
2493 // In device async execution mode, it's possible for device execution to
2494 // lag behind ExecutorState scheduling so much that this is the first
2495 // place a device execution error surfaces.
2496 // If so, all ExecutorState::NodeDone calls have already happened with OK
2497 // status. This is the last defense where StartCancel must be called to
2498 // abort all computation still running on any device.
2499 // TODO(b/124523000): Always call Finish in a separate thread, so even if
2500 // StartCancel blocks the current thread's execution, we won't encounter
2501 // deadlocks caused by inter-op thread exhaustion.
2502 if (cancellation_manager_) {
2503 cancellation_manager_->StartCancel();
2504 }
2505 }
2506 delete this;
2507 runner([=]() { done_cb(status); });
2508 return;
2509 }
2510
2511 if (sync_on_finish_ && status.ok()) {
2512 // Block until the device has finished all queued operations. For
2513 // devices like GPUs that continue to execute Ops after their Compute
2514 // methods have completed, this ensures that control is not returned to
2515 // the user until the step (and its side-effects) has actually completed.
2516 device->Sync([=](Status new_status) mutable {
2517 status.Update(new_status);
2518 delete this;
2519 runner([=]() { done_cb(status); });
2520 });
2521 } else {
2522 delete this;
2523 runner([=]() { done_cb(status); });
2524 }
2525 }
2526
FindOrCreateChildFrame(FrameState * frame,int64 iter,const Node * node,FrameState ** child)2527 void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
2528 const Node* node,
2529 FrameState** child) {
2530 // Get the child frame name.
2531 string enter_name;
2532 Status s = GetNodeAttr(node->attrs(), "frame_name", &enter_name);
2533 DCHECK(s.ok()) << s;
2534 const string child_name = MakeFrameName(frame, iter, enter_name);
2535
2536 {
2537 mutex_lock executor_lock(mu_);
2538 auto it = outstanding_frames_.find(child_name);
2539 if (it != outstanding_frames_.end()) {
2540 *child = it->second;
2541 return;
2542 }
2543 }
2544
2545 // Need to create a new frame instance.
2546 // Note that this new frame instance is created without any locks.
2547 if (vlog_) VLOG(2) << "Create frame: " << child_name;
2548
2549 int parallel_iters;
2550 s = GetNodeAttr(node->attrs(), "parallel_iterations", ¶llel_iters);
2551 DCHECK(s.ok()) << s;
2552 FrameState* temp = new FrameState(impl_, parallel_iters);
2553 temp->frame_name = child_name;
2554 temp->frame_id = Hash64(child_name);
2555 temp->parent_frame = frame;
2556 temp->parent_iter = iter;
2557 temp->InitializeFrameInfo(enter_name);
2558
2559 // 'iterations' is a fixed-length circular buffer.
2560 temp->iterations.resize(temp->max_parallel_iterations + 1);
2561 // Initialize iteration 0.
2562 temp->iterations[0] =
2563 new IterationState(temp->pending_counts, temp->total_input_tensors);
2564
2565 {
2566 mutex_lock executor_lock(mu_);
2567 auto it = outstanding_frames_.find(child_name);
2568 if (it != outstanding_frames_.end()) {
2569 *child = it->second;
2570 } else {
2571 mutex_lock frame_lock(frame->mu);
2572 frame->GetIteration(iter)->outstanding_frame_count++;
2573 outstanding_frames_[child_name] = temp;
2574 *child = temp;
2575 temp = nullptr;
2576 }
2577 }
2578 delete temp; // Not used so delete it.
2579 }
2580
DeleteFrame(FrameState * frame,TaggedNodeSeq * ready)2581 void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
2582 // First, propagate dead_exits (if any) to the parent frame.
2583 FrameState* parent_frame = frame->parent_frame;
2584 const int64 parent_iter = frame->parent_iter;
2585 if (parent_frame != nullptr) {
2586 mutex_lock parent_frame_lock(parent_frame->mu);
2587 // Propagate all the dead exits to the parent frame.
2588 mutex_lock this_frame_lock(frame->mu);
2589 for (const Node* node : frame->dead_exits) {
2590 auto parent_iter_state = parent_frame->GetIteration(parent_iter);
2591 for (const Edge* e : node->out_edges()) {
2592 const Node* dst_node = e->dst();
2593
2594 const auto dst_pending_id =
2595 impl_->gview_.node(dst_node->id())->pending_id;
2596
2597 // TODO(yuanbyu): We don't need this if we require the subgraph
2598 // given to an executor not to contain a sink node.
2599 if (dst_node->IsSink()) continue;
2600
2601 bool dst_dead = true;
2602 bool dst_ready = false;
2603 // We know this is a dead input to dst.
2604 if (IsMerge(dst_node)) {
2605 if (e->IsControlEdge()) {
2606 parent_iter_state->decrement_pending(dst_pending_id, 2);
2607 int count = parent_iter_state->pending(dst_pending_id);
2608 int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
2609 dst_dead = (dead_cnt == dst_node->num_inputs());
2610 dst_ready = (count == 0) || ((count == 1) && dst_dead);
2611 } else {
2612 parent_iter_state->increment_dead_count(dst_pending_id);
2613 const int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
2614 dst_dead = (dead_cnt == dst_node->num_inputs());
2615 dst_ready =
2616 (parent_iter_state->pending(dst_pending_id) == 1) && dst_dead;
2617 }
2618 } else {
2619 parent_iter_state->increment_dead_count(dst_pending_id);
2620 dst_ready =
2621 (parent_iter_state->decrement_pending(dst_pending_id, 1) == 0);
2622 }
2623 if (dst_ready) {
2624 if (IsControlTrigger(dst_node)) dst_dead = false;
2625 ready->emplace_back(dst_node, parent_frame, parent_iter, dst_dead);
2626 parent_iter_state->outstanding_ops++;
2627 }
2628 }
2629 }
2630 }
2631
2632 // Delete the frame.
2633 const string& frame_name = frame->frame_name;
2634 if (vlog_) VLOG(2) << "Delete frame " << frame_name;
2635 {
2636 mutex_lock executor_lock(mu_);
2637 outstanding_frames_.erase(frame_name);
2638 }
2639 delete frame;
2640 }
2641
CleanupFramesIterations(FrameState * frame,int64 iter,TaggedNodeSeq * ready)2642 void ExecutorState::CleanupFramesIterations(FrameState* frame, int64 iter,
2643 TaggedNodeSeq* ready) {
2644 bool is_frame_done = false;
2645 {
2646 mutex_lock frame_lock(frame->mu);
2647 frame->GetIteration(iter)->outstanding_frame_count--;
2648 is_frame_done = frame->CleanupIterations(&impl_->gview_, iter, ready);
2649 }
2650 if (is_frame_done) {
2651 FrameState* parent_frame = frame->parent_frame;
2652 const int64 parent_iter = frame->parent_iter;
2653 DeleteFrame(frame, ready);
2654 if (parent_frame != nullptr) {
2655 // The completion of frame may cause completions in its parent frame.
2656 // So clean things up recursively.
2657 CleanupFramesIterations(parent_frame, parent_iter, ready);
2658 }
2659 }
2660 }
2661
ActivateNodes(const NodeItem * item,const bool is_dead,int64 iter,EntryVector * outputs,TaggedNodeSeq * ready)2662 void ExecutorState::FrameState::ActivateNodes(const NodeItem* item,
2663 const bool is_dead, int64 iter,
2664 EntryVector* outputs,
2665 TaggedNodeSeq* ready) {
2666 const GraphView& gview = executor->gview_;
2667 IterationState* iter_state = GetIteration(iter);
2668 const size_t num_output_edges = item->num_output_edges;
2669 const EdgeInfo* edges = item->output_edge_list();
2670 Entry* input_tensors = iter_state->input_tensors;
2671 for (size_t out_index = 0; out_index < num_output_edges; out_index++) {
2672 const EdgeInfo& e = edges[out_index];
2673 const int dst_id = e.dst_id;
2674 const NodeItem* dst_item = gview.node(dst_id);
2675 const PendingCounts::Handle dst_pending_id = dst_item->pending_id;
2676 const int src_slot = e.output_slot;
2677
2678 // TODO(yuanbyu): We don't need this if we require the subgraph
2679 // given to an executor not to contain a sink node.
2680 if (dst_item->is_sink) continue;
2681
2682 bool dst_dead = false;
2683 bool dst_ready = false;
2684 // True iff this input for dst is needed. We only set this input for
2685 // dst if this flag is true. This is needed to make the thread safety
2686 // analysis happy.
2687 const bool is_control_edge = (src_slot == Graph::kControlSlot);
2688 bool dst_need_input = !is_control_edge;
2689 if (dst_item->is_merge) {
2690 // A merge node is ready if all control inputs have arrived and either
2691 // a) a live data input becomes available or b) all data inputs are
2692 // dead. For Merge, pending's LSB is set iff a live data input has
2693 // arrived.
2694 if (is_control_edge) {
2695 iter_state->decrement_pending(dst_pending_id, 2);
2696 int count = iter_state->pending(dst_pending_id);
2697 int dead_cnt = iter_state->dead_count(dst_pending_id);
2698 dst_dead = (dead_cnt == dst_item->num_inputs);
2699 dst_ready = (count == 0) || ((count == 1) && dst_dead);
2700 } else {
2701 if ((*outputs)[src_slot].has_value) {
2702 // This is a live data input.
2703 int count = iter_state->pending(dst_pending_id);
2704 iter_state->mark_live(dst_pending_id);
2705 // Only the first live edge sets the input and (potentially)
2706 // triggers execution. The low bit of count is set if and
2707 // only if no live input has been used yet (mark_live clears
2708 // it). The node should be started if and only if this is
2709 // the first live input and there are no pending control
2710 // edges, i.e. count == 1.
2711 dst_ready = (count == 1);
2712 dst_need_input = ((count & 0x1) == 1);
2713 } else {
2714 // This is a dead data input. Note that dst_node is dead if node is
2715 // a dead enter. We need this to handle properly a while loop on
2716 // the untaken branch of a conditional.
2717 // TODO(yuanbyu): This is a bit hacky, but a good solution for
2718 // now.
2719 iter_state->increment_dead_count(dst_pending_id);
2720 const int dead_cnt = iter_state->dead_count(dst_pending_id);
2721 dst_dead = (dead_cnt == dst_item->num_inputs) || item->is_enter;
2722 dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead;
2723 dst_need_input = false;
2724 }
2725 }
2726 } else {
2727 const bool increment_dead =
2728 (is_dead || (!is_control_edge && !(*outputs)[src_slot].has_value));
2729 int pending, dead;
2730 iter_state->adjust_for_activation(dst_pending_id, increment_dead,
2731 &pending, &dead);
2732 dst_dead = (dead > 0);
2733 dst_ready = (pending == 0);
2734 }
2735
2736 if (dst_need_input) {
2737 const int dst_slot = e.input_slot;
2738 const int dst_loc = dst_item->input_start + dst_slot;
2739 if (e.is_last) {
2740 input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
2741 } else {
2742 input_tensors[dst_loc] = (*outputs)[src_slot];
2743 }
2744 }
2745
2746 // Add dst to the ready queue if it's ready
2747 if (dst_ready) {
2748 if (dst_item->is_control_trigger) dst_dead = false;
2749 ready->emplace_back(dst_item->node, this, iter, dst_dead);
2750 iter_state->outstanding_ops++;
2751 }
2752 }
2753 }
2754
ActivateNexts(const GraphView * gview,int64 iter,TaggedNodeSeq * ready)2755 void ExecutorState::FrameState::ActivateNexts(const GraphView* gview,
2756 int64 iter,
2757 TaggedNodeSeq* ready) {
2758 // Propagate the deferred NextIteration nodes to the new iteration.
2759 for (auto& node_entry : next_iter_roots) {
2760 const Node* node = node_entry.first;
2761 const Entry& entry = node_entry.second;
2762 const bool is_dead = !entry.has_value;
2763 const NodeItem* item = gview->node(node->id());
2764 EntryVector outputs{entry};
2765 ActivateNodes(item, is_dead, iter, &outputs, ready);
2766 }
2767 next_iter_roots.clear();
2768 }
2769
ActivateLoopInvs(const GraphView * gview,int64 iter,TaggedNodeSeq * ready)2770 void ExecutorState::FrameState::ActivateLoopInvs(const GraphView* gview,
2771 int64 iter,
2772 TaggedNodeSeq* ready) {
2773 // Propagate loop invariants to the new iteration.
2774 for (auto& node_entry : inv_values) {
2775 const Node* node = node_entry.first;
2776 const Entry& entry = node_entry.second;
2777 const bool is_dead = !entry.has_value;
2778 const NodeItem* item = gview->node(node->id());
2779 EntryVector outputs{entry};
2780 ActivateNodes(item, is_dead, iter, &outputs, ready);
2781 }
2782 }
2783
AddLoopInv(const NodeItem * item,const Entry & entry,TaggedNodeSeq * ready)2784 void ExecutorState::FrameState::AddLoopInv(const NodeItem* item,
2785 const Entry& entry,
2786 TaggedNodeSeq* ready) {
2787 // Store this value.
2788 inv_values.push_back({item->node, entry});
2789
2790 // Make this value available to all iterations.
2791 const bool is_dead = !entry.has_value;
2792 for (int i = 0; i <= iteration_count; ++i) {
2793 EntryVector outputs{entry};
2794 ActivateNodes(item, is_dead, i, &outputs, ready);
2795 }
2796 }
2797
IsIterationDone(int64 iter)2798 bool ExecutorState::FrameState::IsIterationDone(int64 iter) {
2799 IterationState* iter_state = GetIteration(iter);
2800 if (iter_state->outstanding_ops == 0 &&
2801 iter_state->outstanding_frame_count == 0) {
2802 if (iter == 0) {
2803 // The enclosing frame has no pending input.
2804 return num_pending_inputs == 0;
2805 } else {
2806 // The preceding iteration is deleted (and therefore done).
2807 return (GetIteration(iter - 1) == nullptr);
2808 }
2809 }
2810 return false;
2811 }
2812
IncrementIteration(const GraphView * gview,TaggedNodeSeq * ready)2813 void ExecutorState::FrameState::IncrementIteration(const GraphView* gview,
2814 TaggedNodeSeq* ready) {
2815 iteration_count++;
2816 const int64 next_iter = iteration_count;
2817
2818 // Initialize the next iteration.
2819 IterationState* iter_state =
2820 new IterationState(pending_counts, total_input_tensors);
2821 SetIteration(next_iter, iter_state);
2822 num_outstanding_iterations++;
2823 dead_exits.clear();
2824
2825 // Activate the successors of the deferred roots in the new iteration.
2826 ActivateNexts(gview, next_iter, ready);
2827
2828 // Activate the loop invariants in the new iteration.
2829 ActivateLoopInvs(gview, next_iter, ready);
2830 }
2831
CleanupIterations(const GraphView * gview,int64 iter,TaggedNodeSeq * ready)2832 bool ExecutorState::FrameState::CleanupIterations(const GraphView* gview,
2833 int64 iter,
2834 TaggedNodeSeq* ready) {
2835 int64 curr_iter = iter;
2836 while (curr_iter <= iteration_count && IsIterationDone(curr_iter)) {
2837 // Delete the iteration curr_iter.
2838 delete GetIteration(curr_iter);
2839 SetIteration(curr_iter, nullptr);
2840 --num_outstanding_iterations;
2841 ++curr_iter;
2842
2843 // When one iteration is completed, we check for deferred iteration,
2844 // and start it if there is one.
2845 if (!next_iter_roots.empty()) {
2846 IncrementIteration(gview, ready);
2847 }
2848 }
2849 return IsFrameDone();
2850 }
2851
RunAsync(const Args & args,DoneCallback done)2852 void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
2853 (new ExecutorState(args, this))->RunAsync(std::move(done));
2854 }
2855
2856 } // namespace
2857
NewLocalExecutor(const LocalExecutorParams & params,std::unique_ptr<const Graph> graph,Executor ** executor)2858 Status NewLocalExecutor(const LocalExecutorParams& params,
2859 std::unique_ptr<const Graph> graph,
2860 Executor** executor) {
2861 ExecutorImpl* impl = new ExecutorImpl(params, std::move(graph));
2862 const Status s = impl->Initialize();
2863 if (s.ok()) {
2864 *executor = impl;
2865 } else {
2866 delete impl;
2867 }
2868 return s;
2869 }
2870
CreateNonCachedKernel(Device * device,FunctionLibraryRuntime * flib,const NodeDef & ndef,int graph_def_version,OpKernel ** kernel)2871 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
2872 const NodeDef& ndef, int graph_def_version,
2873 OpKernel** kernel) {
2874 const auto device_type = DeviceType(device->attributes().device_type());
2875 auto allocator = device->GetAllocator(AllocatorAttributes());
2876 return CreateOpKernel(device_type, device, allocator, flib, ndef,
2877 graph_def_version, kernel);
2878 }
2879
DeleteNonCachedKernel(OpKernel * kernel)2880 void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }
2881
2882 namespace {
2883
2884 class DefaultExecutorRegistrar {
2885 public:
DefaultExecutorRegistrar()2886 DefaultExecutorRegistrar() {
2887 Factory* factory = new Factory;
2888 ExecutorFactory::Register("", factory);
2889 ExecutorFactory::Register("DEFAULT", factory);
2890 }
2891
2892 private:
2893 class Factory : public ExecutorFactory {
NewExecutor(const LocalExecutorParams & params,std::unique_ptr<const Graph> graph,std::unique_ptr<Executor> * out_executor)2894 Status NewExecutor(const LocalExecutorParams& params,
2895 std::unique_ptr<const Graph> graph,
2896 std::unique_ptr<Executor>* out_executor) override {
2897 Executor* ret = nullptr;
2898 TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret));
2899 out_executor->reset(ret);
2900 return Status::OK();
2901 }
2902 };
2903 };
2904 static DefaultExecutorRegistrar registrar;
2905
2906 } // namespace
2907
2908 } // namespace tensorflow
2909