1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/propagator_state.h"
17 
18 #include "tensorflow/core/common_runtime/graph_view.h"
19 #include "tensorflow/core/common_runtime/immutable_executor_state.h"
20 #include "tensorflow/core/common_runtime/propagator_debug_utils.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/lib/hash/hash.h"
23 #include "tensorflow/core/platform/hash.h"
24 #include "tensorflow/core/profiler/lib/traceme.h"
25 
26 namespace tensorflow {
27 
PropagatorState(const ImmutableExecutorState & immutable_state,int64 step_id,bool vlog)28 PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state,
29                                  int64 step_id, bool vlog)
30     : immutable_state_(immutable_state),
31       step_id_(step_id),
32       vlog_(vlog || VLOG_IS_ON(1)) {
33   // We start the entire execution in iteration 0 of the root frame
34   // so let us create the root frame and the state for iteration 0.
35   // We assume root_frame_->frame_name.empty().
36   root_frame_ = new FrameState(immutable_state_, 1);
37   root_frame_->frame_id = 0;  // must be 0
38   root_frame_->InitializeFrameInfo(immutable_state_.get_root_frame_info());
39 
40   // Initialize iteration 0.
41   root_frame_->SetIteration(
42       0, new PropagatorState::IterationState(0, root_frame_->pending_counts,
43                                              root_frame_->total_input_tensors));
44 
45   outstanding_frames_.emplace(root_frame_->frame_id, root_frame_);
46 }
47 
~PropagatorState()48 PropagatorState::~PropagatorState() {
49   for (auto name_frame : outstanding_frames_) {
50     delete name_frame.second;
51   }
52 }
53 
ActivateRoots(gtl::ArraySlice<const NodeItem * > roots,TaggedNodeSeq * ready)54 void PropagatorState::ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
55                                     TaggedNodeSeq* ready) {
56   mutex_lock l(root_frame_->mu);
57   IterationState* root_iter = root_frame_->GetIteration(0);
58   for (const NodeItem* item : roots) {
59     DCHECK_EQ(item->num_inputs, 0);
60     ready->emplace_back(item, root_frame_, root_iter, false);
61   }
62   root_iter->outstanding_ops = ready->size();
63 }
64 
PropagateOutputs(const TaggedNode & tagged_node,EntryVector * outputs,TaggedNodeSeq * ready)65 void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
66                                        EntryVector* outputs,
67                                        TaggedNodeSeq* ready) {
68   profiler::TraceMe activity(
69       [&]() {
70         return strings::StrCat(
71             "ExecutorPropagateOutputs#", "id=", step_id_,
72             ",kernel_name=", tagged_node.node_item->kernel->name_view(),
73             ",num_output_edges=", tagged_node.node_item->num_output_edges,
74             ",num_output_control_edges=",
75             tagged_node.node_item->num_output_control_edges, "#");
76       },
77       profiler::GetTFTraceMeLevel(/*is_expensive=*/false));
78 
79   const NodeItem* const item = tagged_node.node_item;
80   FrameState* const input_frame = tagged_node.input_frame;
81   IterationState* const input_iter = tagged_node.input_iter;
82   const bool is_dead = tagged_node.is_dead;
83 
84   // Propagates outputs along out edges, and puts newly ready nodes
85   // into the ready queue.
86   DCHECK(ready->empty());
87   bool is_frame_done = false;
88   FrameState* output_frame = input_frame;
89   IterationState* output_iter = input_iter;
90 
91   if (!item->is_enter_exit_or_next_iter) {
92     // Fast path for node types that don't need special handling.
93     // This is the case for most nodes.
94     DCHECK_EQ(input_frame, output_frame);
95     FrameState* frame = input_frame;
96     is_frame_done = frame->ActivateNodesAndAdjustOutstanding(
97         item, is_dead, output_iter, outputs, ready);
98   } else if (item->is_enter) {
99     FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame);
100     {
101       mutex_lock l(output_frame->mu);
102       output_iter = output_frame->GetIteration(0);
103       if (item->is_constant_enter) {
104         // Propagate to all active iterations if this is a loop invariant.
105         output_frame->AddLoopInv(item, (*outputs)[0], ready);
106       } else {
107         int activated = output_frame->ActivateNodesLocked(
108             item, is_dead, output_iter, outputs, ready);
109         output_frame->AdjustOutstandingOpsLocked(output_iter, activated, ready);
110       }
111       output_frame->num_pending_inputs--;
112     }
113     is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready);
114   } else if (item->is_exit) {
115     if (is_dead) {
116       mutex_lock l(input_frame->mu);
117       // Stop and remember this node if it is a dead exit.
118       if (input_iter->iter_num == input_frame->iteration_count) {
119         input_frame->dead_exits.push_back(item);
120       }
121       is_frame_done =
122           input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
123     } else {
124       output_frame = input_frame->parent_frame;
125       output_iter = input_frame->parent_iter;
126       {
127         mutex_lock l(output_frame->mu);
128         int activated = output_frame->ActivateNodesLocked(
129             item, is_dead, output_iter, outputs, ready);
130         output_frame->AdjustOutstandingOpsLocked(output_iter, activated, ready);
131       }
132       is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready);
133     }
134   } else {
135     DCHECK(item->is_next_iteration);
136     mutex_lock l(input_frame->mu);
137     if (is_dead) {
138       // Stop the deadness propagation.
139       output_frame = nullptr;
140     } else {
141       if (input_iter->iter_num == input_frame->iteration_count &&
142           input_frame->num_outstanding_iterations ==
143               input_frame->max_parallel_iterations) {
144         // Reached the maximum for parallel iterations.
145         input_frame->next_iter_roots.push_back({item, (*outputs)[0]});
146         output_frame = nullptr;
147       } else {
148         // If this is a new iteration, start it.
149         if (input_iter->iter_num == input_frame->iteration_count) {
150           output_iter = input_frame->IncrementIteration(ready);
151         } else {
152           output_iter = input_frame->GetIteration(input_iter->iter_num + 1);
153         }
154       }
155     }
156     if (output_frame != nullptr) {
157       // This is the case when node is not Enter, Exit, or NextIteration.
158       DCHECK(input_frame == output_frame);
159       int activated = output_frame->ActivateNodesLocked(
160           item, is_dead, output_iter, outputs, ready);
161       output_frame->AdjustOutstandingOpsLocked(output_iter, activated, ready);
162     }
163     is_frame_done =
164         input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
165   }
166 
167   // At this point, this node is completely done. We also know if the
168   // completion of this node makes its frame completed.
169   if (is_frame_done) {
170     FrameState* parent_frame = input_frame->parent_frame;
171     IterationState* parent_iter = input_frame->parent_iter;
172     DeleteFrame(input_frame, ready);
173     if (parent_frame != nullptr) {
174       // The completion of frame may cause completions in its parent frame.
175       // So clean things up recursively.
176       CleanupFramesIterations(parent_frame, parent_iter, ready);
177     }
178   }
179 }
180 
DumpIterationState(const FrameState * frame,IterationState * iteration)181 void PropagatorState::DumpIterationState(const FrameState* frame,
182                                          IterationState* iteration) {
183   const std::vector<const NodeItem*>* nodes = frame->nodes;
184   // Dump any waiting nodes that are holding on to tensors.
185   for (const NodeItem* node : *nodes) {
186     PendingCounts::Handle pending_id =
187         immutable_state_.pending_ids()[node->node_id];
188     if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
189         iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
190       DumpPendingNodeState(*node, iteration->input_tensors, false);
191     }
192   }
193   // Then the active nodes.
194   for (const NodeItem* node : *nodes) {
195     PendingCounts::Handle pending_id =
196         immutable_state_.pending_ids()[node->node_id];
197     if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
198       DumpActiveNodeState(*node, iteration->input_tensors);
199     }
200   }
201   // Show all input tensors in use.
202   const int total_input_tensors = frame->total_input_tensors;
203   size_t total_bytes = 0;
204   for (int i = 0; i < total_input_tensors; ++i) {
205     const Entry& input = iteration->input_tensors[i];
206     const Tensor* tensor = GetTensorValueForDump(input);
207     if (tensor->IsInitialized()) {
208       LOG(WARNING) << "    Input " << i << ": "
209                    << strings::StrCat(
210                           "Tensor<type: ", DataTypeString(tensor->dtype()),
211                           " shape: ", tensor->shape().DebugString(),
212                           ", bytes: ", tensor->TotalBytes(), ">");
213       total_bytes += tensor->TotalBytes();
214     }
215   }
216   LOG(WARNING) << "    Total bytes " << total_bytes;
217 }
218 
DumpState()219 void PropagatorState::DumpState() {
220   mutex_lock l(mu_);
221   LOG(WARNING) << "Dumping state";
222   for (auto& frame : outstanding_frames_) {
223     LOG(WARNING) << frame.first;
224     FrameState* frame_state = frame.second;
225     frame_state->DumpIterationState(this);
226   }
227 }
228 
FindOrCreateChildFrame(FrameState * frame,IterationState * iter_state,const NodeItem & node_item,FrameState ** child)229 void PropagatorState::FindOrCreateChildFrame(FrameState* frame,
230                                              IterationState* iter_state,
231                                              const NodeItem& node_item,
232                                              FrameState** child) {
233   // Get the child frame name.
234   const ImmutableExecutorState::FrameInfo& frame_info =
235       immutable_state_.get_enter_frame_info(node_item);
236 
237   const uint64 child_id = Hash64Combine(
238       frame->frame_id,
239       Hash64Combine(iter_state->iter_num, Hash64(frame_info.name)));
240 
241   {
242     tf_shared_lock executor_lock(mu_);
243     auto it = outstanding_frames_.find(child_id);
244     if (it != outstanding_frames_.end()) {
245       *child = it->second;
246       return;
247     }
248   }
249 
250   // Need to create a new frame instance.
251   // Note that this new frame instance is created without any locks.
252   if (vlog_) {
253     const string child_name = strings::StrCat(
254         frame->frame_name, ";", iter_state->iter_num, ";", frame_info.name);
255     VLOG(2) << "Create frame: " << child_name << " id: " << child_id;
256   }
257 
258   FrameState* temp =
259       new FrameState(immutable_state_, frame_info.parallel_iterations);
260   temp->frame_id = child_id;
261   temp->parent_frame = frame;
262   temp->parent_iter = iter_state;
263   temp->InitializeFrameInfo(frame_info);
264 
265   // Initialize iteration 0.
266   {
267     mutex_lock l(temp->mu);
268     temp->SetIteration(0, new IterationState(0, temp->pending_counts,
269                                              temp->total_input_tensors));
270   }
271 
272   {
273     mutex_lock executor_lock(mu_);
274     auto it = outstanding_frames_.find(child_id);
275     if (it != outstanding_frames_.end()) {
276       *child = it->second;
277     } else {
278       mutex_lock frame_lock(frame->mu);
279       iter_state->outstanding_frame_count++;
280       outstanding_frames_[child_id] = temp;
281       *child = temp;
282       temp = nullptr;
283     }
284   }
285   delete temp;  // Not used so delete it.
286 }
287 
DeleteFrame(FrameState * frame,TaggedNodeSeq * ready)288 void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
289   // First, propagate dead_exits (if any) to the parent frame.
290   FrameState* parent_frame = frame->parent_frame;
291   IterationState* parent_iter_state = frame->parent_iter;
292   if (parent_frame != nullptr) {
293     mutex_lock parent_frame_lock(parent_frame->mu);
294     // Propagate all the dead exits to the parent frame.
295     mutex_lock this_frame_lock(frame->mu);
296 
297     for (const NodeItem* item : frame->dead_exits) {
298       auto maybe_add_to_ready = [&](const NodeItem& dst_item, bool dst_ready,
299                                     bool dst_dead) {
300         if (dst_ready) {
301           if (dst_item.is_control_trigger) dst_dead = false;
302           ready->emplace_back(&dst_item, parent_frame, parent_iter_state,
303                               dst_dead);
304           parent_iter_state->outstanding_ops++;
305         }
306       };
307 
308       auto propagate_to_non_merge = [&](PendingCounts::Handle dst_pending_id) {
309         parent_iter_state->increment_dead_count(dst_pending_id);
310         return parent_iter_state->decrement_pending(dst_pending_id, 1) == 0;
311       };
312 
313       for (const EdgeInfo& e : item->output_edges()) {
314         const NodeItem& dst_item =
315             immutable_state_.graph_view().node_ref(e.dst_id);
316         const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];
317 
318         bool dst_dead = true;
319         bool dst_ready;
320         // We know this is a dead input to dst.
321         if (dst_item.is_merge) {
322           parent_iter_state->increment_dead_count(dst_pending_id);
323           const int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
324           dst_dead = (dead_cnt == dst_item.num_inputs);
325           dst_ready =
326               (parent_iter_state->pending(dst_pending_id) == 1) && dst_dead;
327         } else {
328           dst_ready = propagate_to_non_merge(dst_pending_id);
329         }
330         maybe_add_to_ready(dst_item, dst_ready, dst_dead);
331       }
332 
333       for (const ControlEdgeInfo& e : item->output_control_edges()) {
334         const NodeItem& dst_item =
335             immutable_state_.graph_view().node_ref(e.dst_id);
336         const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];
337 
338         bool dst_dead;
339         bool dst_ready;
340         // We know this is a dead input to dst.
341         if (dst_item.is_merge) {
342           parent_iter_state->decrement_pending(dst_pending_id, 2);
343           int count = parent_iter_state->pending(dst_pending_id);
344           int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
345           dst_dead = (dead_cnt == dst_item.num_inputs);
346           dst_ready = (count == 0) || ((count == 1) && dst_dead);
347         } else {
348           dst_dead = true;
349           dst_ready = propagate_to_non_merge(dst_pending_id);
350         }
351         maybe_add_to_ready(dst_item, dst_ready, dst_dead);
352       }
353     }
354   }
355 
356   // Delete the frame.
357   if (vlog_) VLOG(2) << "Delete frame " << frame->frame_id;
358   {
359     mutex_lock executor_lock(mu_);
360     outstanding_frames_.erase(frame->frame_id);
361   }
362   delete frame;
363 }
364 
CleanupFramesIterations(FrameState * frame,IterationState * iter_state,TaggedNodeSeq * ready)365 void PropagatorState::CleanupFramesIterations(FrameState* frame,
366                                               IterationState* iter_state,
367                                               TaggedNodeSeq* ready) {
368   bool is_frame_done = false;
369   {
370     mutex_lock frame_lock(frame->mu);
371     iter_state->outstanding_frame_count--;
372     is_frame_done = frame->CleanupIterations(iter_state, ready);
373   }
374   if (is_frame_done) {
375     FrameState* parent_frame = frame->parent_frame;
376     IterationState* parent_iter = frame->parent_iter;
377     DeleteFrame(frame, ready);
378     if (parent_frame != nullptr) {
379       // The completion of frame may cause completions in its parent frame.
380       // So clean things up recursively.
381       CleanupFramesIterations(parent_frame, parent_iter, ready);
382     }
383   }
384 }
385 
386 template <bool atomic>
ActivateNodesFastPathInternal(const NodeItem * item,const bool is_dead,IterationState * iter_state,EntryVector * outputs,TaggedNodeSeq * ready)387 int PropagatorState::FrameState::ActivateNodesFastPathInternal(
388     const NodeItem* item, const bool is_dead, IterationState* iter_state,
389     EntryVector* outputs, TaggedNodeSeq* ready) {
390   // If we know that none of the item's edge destinations require special
391   // handling (i.e. none of the nodes is a merge or control trigger node), we
392   // can take a fast path that avoids accessing the destination NodeItem.
393   const GraphView& gview = immutable_state.graph_view();
394   int new_outstanding = 0;
395 
396 // Add dst to the ready queue if it's ready
397 //
398 // NOTE(mrry): Use a macro here instead of a lambda, because this method is
399 // performance-critical and we need to ensure that the code is inlined.
400 #define MAYBE_ADD_TO_READY(dst_id, adjust_result)         \
401   do {                                                    \
402     if (!adjust_result.any_pending) {                     \
403       const NodeItem* dst_item = &gview.node_ref(dst_id); \
404       TaggedNode& t = ready->emplace_back();              \
405       t.node_item = dst_item;                             \
406       t.input_frame = this;                               \
407       t.input_iter = iter_state;                          \
408       t.is_dead = adjust_result.any_dead;                 \
409       new_outstanding++;                                  \
410     }                                                     \
411   } while (0);
412 
413   Entry* input_tensors = iter_state->input_tensors;
414   for (const EdgeInfo& e : item->output_edges()) {
415     const int dst_id = e.dst_id;
416     const PendingCounts::Handle dst_pending_id =
417         immutable_state.pending_ids()[dst_id];
418     const int src_slot = e.output_slot;
419 
420     const bool increment_dead =
421         (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE));
422     const int dst_loc = e.input_slot;
423     if (e.is_last) {
424       input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
425     } else {
426       input_tensors[dst_loc] = (*outputs)[src_slot];
427     }
428     const PendingCounts::AdjustResult adjust_result =
429         atomic
430             ? iter_state->adjust_for_activation_atomic(dst_pending_id,
431                                                        increment_dead)
432             : iter_state->adjust_for_activation(dst_pending_id, increment_dead);
433     MAYBE_ADD_TO_READY(dst_id, adjust_result);
434   }
435 
436   for (const ControlEdgeInfo& e : item->output_control_edges()) {
437     const int dst_id = e.dst_id;
438     const PendingCounts::Handle dst_pending_id =
439         immutable_state.pending_ids()[dst_id];
440     const PendingCounts::AdjustResult adjust_result =
441         atomic
442             ? iter_state->adjust_for_activation_atomic(dst_pending_id, is_dead)
443             : iter_state->adjust_for_activation(dst_pending_id, is_dead);
444     MAYBE_ADD_TO_READY(dst_id, adjust_result);
445   }
446 
447   return new_outstanding;
448 #undef MAYBE_ADD_TO_READY
449 }
450 
ActivateNodesSlowPath(const NodeItem * item,const bool is_dead,IterationState * iter_state,EntryVector * outputs,TaggedNodeSeq * ready)451 int PropagatorState::FrameState::ActivateNodesSlowPath(
452     const NodeItem* item, const bool is_dead, IterationState* iter_state,
453     EntryVector* outputs, TaggedNodeSeq* ready) {
454   // If any of the edge destinations is a merge or a control trigger node,
455   // we need to read each destination NodeItem to determine what action
456   // to take.
457   const GraphView& gview = immutable_state.graph_view();
458   int activated = 0;
459   auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item,
460                                 bool dst_ready, bool dst_dead) {
461     // Add dst to the ready queue if it's ready
462     if (dst_ready) {
463       if (dst_item->is_control_trigger) dst_dead = false;
464       ready->emplace_back(dst_item, this, iter_state, dst_dead);
465       activated++;
466     }
467   };
468 
469   Entry* input_tensors = iter_state->input_tensors;
470 
471   for (const EdgeInfo& e : item->output_edges()) {
472     const int dst_id = e.dst_id;
473     const NodeItem* dst_item = &gview.node_ref(dst_id);
474     const PendingCounts::Handle dst_pending_id =
475         immutable_state.pending_ids()[dst_id];
476     const int src_slot = e.output_slot;
477 
478     bool dst_dead = false;
479     bool dst_ready = false;
480     bool dst_need_input = true;
481 
482     if (dst_item->is_merge) {
483       // A merge node is ready if all control inputs have arrived and either
484       // a) a live data input becomes available or b) all data inputs are
485       // dead. For Merge, pending's LSB is set iff a live data input has
486       // arrived.
487       if ((*outputs)[src_slot].state != Entry::State::NO_VALUE) {
488         // This is a live data input.
489         int count = iter_state->pending(dst_pending_id);
490         iter_state->mark_live(dst_pending_id);
491         // Only the first live edge sets the input and (potentially)
492         // triggers execution. The low bit of count is set if and
493         // only if no live input has been used yet (mark_live clears
494         // it). The node should be started if and only if this is
495         // the first live input and there are no pending control
496         // edges, i.e. count == 1.
497         dst_ready = (count == 1);
498         dst_need_input = ((count & 0x1) == 1);
499       } else {
500         // This is a dead data input. Note that dst_node is dead if node is
501         // a dead enter. We need this to handle properly a while loop on
502         // the untaken branch of a conditional.
503         // TODO(yuanbyu): This is a bit hacky, but a good solution for
504         // now.
505         iter_state->increment_dead_count(dst_pending_id);
506         const int dead_cnt = iter_state->dead_count(dst_pending_id);
507         dst_dead = (dead_cnt == dst_item->num_inputs) || item->is_enter;
508         dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead;
509         dst_need_input = false;
510       }
511     } else {
512       // Handle all other (non-merge) nodes.
513       const bool increment_dead =
514           (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE));
515       const PendingCounts::AdjustResult adjust_result =
516           iter_state->adjust_for_activation(dst_pending_id, increment_dead);
517       dst_dead = adjust_result.any_dead;
518       dst_ready = !adjust_result.any_pending;
519     }
520 
521     if (dst_need_input) {
522       const int dst_loc = e.input_slot;
523       if (e.is_last) {
524         input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
525       } else {
526         input_tensors[dst_loc] = (*outputs)[src_slot];
527       }
528     }
529 
530     maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
531   }
532 
533   for (const ControlEdgeInfo& e : item->output_control_edges()) {
534     const int dst_id = e.dst_id;
535     const NodeItem* dst_item = &gview.node_ref(dst_id);
536     const PendingCounts::Handle dst_pending_id =
537         immutable_state.pending_ids()[dst_id];
538 
539     bool dst_dead;
540     bool dst_ready;
541     if (dst_item->is_merge) {
542       // A merge node is ready if all control inputs have arrived and either
543       // a) a live data input becomes available or b) all data inputs are
544       // dead. For Merge, pending's LSB is set iff a live data input has
545       // arrived.
546       iter_state->decrement_pending(dst_pending_id, 2);
547       int count = iter_state->pending(dst_pending_id);
548       int dead_cnt = iter_state->dead_count(dst_pending_id);
549       dst_dead = (dead_cnt == dst_item->num_inputs);
550       dst_ready = (count == 0) || ((count == 1) && dst_dead);
551     } else {
552       // Handle all other (non-merge) nodes.
553       const PendingCounts::AdjustResult adjust_result =
554           iter_state->adjust_for_activation(dst_pending_id, is_dead);
555       dst_dead = adjust_result.any_dead;
556       dst_ready = !adjust_result.any_pending;
557     }
558     maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
559   }
560 
561   return activated;
562 }
563 
ActivateNodesAndAdjustOutstanding(const NodeItem * item,const bool is_dead,IterationState * iter_state,EntryVector * outputs,TaggedNodeSeq * ready)564 bool PropagatorState::FrameState::ActivateNodesAndAdjustOutstanding(
565     const NodeItem* item, const bool is_dead, IterationState* iter_state,
566     EntryVector* outputs, TaggedNodeSeq* ready) {
567   if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) {
568     mutex_lock l(mu);
569     int activated =
570         ActivateNodesSlowPath(item, is_dead, iter_state, outputs, ready);
571     return AdjustOutstandingOpsLocked(iter_state, activated - 1, ready);
572   }
573   {
574     tf_shared_lock l(mu);
575     int activated =
576         ActivateNodesFastPathShared(item, is_dead, iter_state, outputs, ready);
577     bool iter_done = AdjustOutstandingOpsFastPath(iter_state, activated - 1);
578     if (!iter_done) return false;
579   }
580   mutex_lock l(mu);
581   return CleanupIterations(iter_state, ready);
582 }
583 
ActivateNodesLocked(const NodeItem * item,const bool is_dead,IterationState * iter_state,EntryVector * outputs,TaggedNodeSeq * ready)584 int PropagatorState::FrameState::ActivateNodesLocked(const NodeItem* item,
585                                                      const bool is_dead,
586                                                      IterationState* iter_state,
587                                                      EntryVector* outputs,
588                                                      TaggedNodeSeq* ready) {
589   if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) {
590     return ActivateNodesSlowPath(item, is_dead, iter_state, outputs, ready);
591   } else {
592     return ActivateNodesFastPathLocked(item, is_dead, iter_state, outputs,
593                                        ready);
594   }
595 }
596 
ActivateNexts(IterationState * iter_state,TaggedNodeSeq * ready)597 void PropagatorState::FrameState::ActivateNexts(IterationState* iter_state,
598                                                 TaggedNodeSeq* ready) {
599   int activated = 0;
600   // Propagate the deferred NextIteration nodes to the new iteration.
601   for (auto& node_entry : next_iter_roots) {
602     const NodeItem* item = node_entry.first;
603     const Entry& entry = node_entry.second;
604     const bool is_dead = entry.state == Entry::State::NO_VALUE;
605     EntryVector outputs{entry};
606     activated +=
607         ActivateNodesLocked(item, is_dead, iter_state, &outputs, ready);
608   }
609   next_iter_roots.clear();
610   AdjustOutstandingOpsLocked(iter_state, activated, ready);
611 }
612 
ActivateLoopInvs(IterationState * iter_state,TaggedNodeSeq * ready)613 void PropagatorState::FrameState::ActivateLoopInvs(IterationState* iter_state,
614                                                    TaggedNodeSeq* ready) {
615   // Propagate loop invariants to the new iteration.
616   int activated = 0;
617   for (auto& node_entry : inv_values) {
618     const NodeItem* item = node_entry.first;
619     const Entry& entry = node_entry.second;
620     const bool is_dead = entry.state == Entry::State::NO_VALUE;
621     EntryVector outputs{entry};
622     activated +=
623         ActivateNodesLocked(item, is_dead, iter_state, &outputs, ready);
624   }
625   AdjustOutstandingOpsLocked(iter_state, activated, ready);
626 }
627 
AddLoopInv(const NodeItem * item,const Entry & entry,TaggedNodeSeq * ready)628 void PropagatorState::FrameState::AddLoopInv(const NodeItem* item,
629                                              const Entry& entry,
630                                              TaggedNodeSeq* ready) {
631   // Store this value.
632   inv_values.push_back({item, entry});
633 
634   // Make this value available to all iterations.
635   const bool is_dead = entry.state == Entry::State::NO_VALUE;
636   for (int i = 0; i <= iteration_count; ++i) {
637     EntryVector outputs{entry};
638     IterationState* iter_state = GetIteration(i);
639     int activated =
640         ActivateNodesLocked(item, is_dead, iter_state, &outputs, ready);
641     AdjustOutstandingOpsLocked(iter_state, activated, ready);
642   }
643 }
644 
IsIterationDone(IterationState * iter_state)645 bool PropagatorState::FrameState::IsIterationDone(IterationState* iter_state) {
646   if (iter_state->outstanding_ops == 0 &&
647       iter_state->outstanding_frame_count == 0) {
648     if (iter_state->iter_num == 0) {
649       // The enclosing frame has no pending input.
650       return num_pending_inputs == 0;
651     } else {
652       // The preceding iteration is deleted (and therefore done).
653       return (GetIteration(iter_state->iter_num - 1) == nullptr);
654     }
655   }
656   return false;
657 }
658 
659 PropagatorState::IterationState*
IncrementIteration(TaggedNodeSeq * ready)660 PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) {
661   iteration_count++;
662 
663   // Initialize the next iteration.
664   IterationState* next_iter =
665       new IterationState(iteration_count, pending_counts, total_input_tensors);
666   SetIteration(iteration_count, next_iter);
667   num_outstanding_iterations++;
668   dead_exits.clear();
669 
670   // Activate the successors of the deferred roots in the new iteration.
671   ActivateNexts(next_iter, ready);
672 
673   // Activate the loop invariants in the new iteration.
674   ActivateLoopInvs(next_iter, ready);
675 
676   return next_iter;
677 }
678 
CleanupIterations(IterationState * iter_state,TaggedNodeSeq * ready)679 bool PropagatorState::FrameState::CleanupIterations(IterationState* iter_state,
680                                                     TaggedNodeSeq* ready) {
681   int64 curr_iter = iter_state->iter_num;
682   while (curr_iter <= iteration_count && IsIterationDone(iter_state)) {
683     delete iter_state;
684     SetIteration(curr_iter, nullptr);
685     --num_outstanding_iterations;
686     ++curr_iter;
687 
688     // When one iteration is completed, we check for deferred iteration,
689     // and start it if there is one.
690     if (!next_iter_roots.empty()) {
691       IncrementIteration(ready);
692     }
693 
694     if (curr_iter <= iteration_count) {
695       iter_state = GetIteration(curr_iter);
696     }
697   }
698   return IsFrameDone();
699 }
700 
InitializeFrameInfo(const ImmutableExecutorState::FrameInfo & finfo)701 void PropagatorState::FrameState::InitializeFrameInfo(
702     const ImmutableExecutorState::FrameInfo& finfo) {
703   pending_counts = finfo.pending_counts.get();
704   total_input_tensors = finfo.total_inputs;
705   num_pending_inputs = finfo.input_count;
706   nodes = finfo.nodes.get();
707 }
708 
SetIteration(int64 iter,IterationState * state)709 void PropagatorState::FrameState::SetIteration(int64 iter,
710                                                IterationState* state)
711     TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
712   size_t index = iter % (max_parallel_iterations + 1);
713   DCHECK(state == nullptr || iterations[index] == nullptr);
714   iterations_raw[index] = state;
715   if (index == 0) {
716     iterations_first = state;
717   }
718 }
719 
720 // Decrement the outstanding op count and clean up the iterations in the
721 // frame. Return true iff the execution of the frame is done.
DecrementOutstandingOps(IterationState * iter_state,TaggedNodeSeq * ready)722 bool PropagatorState::FrameState::DecrementOutstandingOps(
723     IterationState* iter_state, TaggedNodeSeq* ready) {
724   return AdjustOutstandingOps(iter_state, -1, ready);
725 }
726 
AdjustOutstandingOps(IterationState * iter_state,int delta,TaggedNodeSeq * ready)727 bool PropagatorState::FrameState::AdjustOutstandingOps(
728     IterationState* iter_state, int delta, TaggedNodeSeq* ready) {
729   // Given the following profile of values of 'delta' for wide_deep model from
730   // the TF model garden:
731   //
732   // Count  Value
733   // ---------------
734   // 757938 delta=0x0
735   // 541713 delta=0xffffffff
736   // 138115 delta=0x1
737   //  58770 delta=0x2
738   //   5394 delta=0x3
739   //   4669 delta=0x4
740   //   2037 delta=0xa
741   //   1646 delta=0x7
742   //   1632 delta=0x6
743   //   1613 delta=0x6c
744   //   1224 delta=0x5
745   //    409 delta=0x53
746   //     17 delta=0x86
747   //
748   // ... it's worth no-opping out when delta == 0 to avoid the atomic
749   // instruction.
750   if (delta == 0) {
751     return false;
752   }
753   {
754     tf_shared_lock sl(mu);
755     if (TF_PREDICT_TRUE(!AdjustOutstandingOpsFastPath(iter_state, delta))) {
756       return false;
757     }
758   }
759   mutex_lock l(mu);
760   DCHECK(IsIterationDone(iter_state));
761   return CleanupIterations(iter_state, ready);
762 }
763 
AdjustOutstandingOpsFastPath(IterationState * iter_state,int delta)764 bool PropagatorState::FrameState::AdjustOutstandingOpsFastPath(
765     IterationState* iter_state, int delta) {
766   auto old_val = iter_state->outstanding_ops.fetch_add(delta);
767   return (old_val + delta == 0) && IsIterationDone(iter_state);
768 }
769 
770 // Decrement the outstanding op count and clean up the iterations in the
771 // frame. Return true iff the execution of the frame is done.
DecrementOutstandingOpsLocked(IterationState * iter_state,TaggedNodeSeq * ready)772 bool PropagatorState::FrameState::DecrementOutstandingOpsLocked(
773     IterationState* iter_state, TaggedNodeSeq* ready)
774     TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
775   return AdjustOutstandingOpsLocked(iter_state, -1, ready);
776 }
777 
AdjustOutstandingOpsLocked(IterationState * iter_state,int delta,TaggedNodeSeq * ready)778 bool PropagatorState::FrameState::AdjustOutstandingOpsLocked(
779     IterationState* iter_state, int delta, TaggedNodeSeq* ready) {
780   // We hold the lock, so we don't need to use an atomic modification.
781   auto cur_val = iter_state->outstanding_ops.load(std::memory_order_relaxed);
782   DCHECK(delta >= 0 || cur_val >= -delta)
783       << "cannot adjust outstanding_ops by " << delta
784       << " when current value is " << cur_val;
785   auto new_val = cur_val + delta;
786   iter_state->outstanding_ops.store(new_val, std::memory_order_relaxed);
787   if (new_val != 0) {
788     return false;
789   }
790   return CleanupIterations(iter_state, ready);
791 }
792 
793 // Returns true if the computation in the frame is completed.
IsFrameDone()794 bool PropagatorState::FrameState::IsFrameDone()
795     TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
796   return (num_pending_inputs == 0 && num_outstanding_iterations == 0);
797 }
798 
799 }  // namespace tensorflow
800