1 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
2 #define TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
3 
4 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
5 
6 Licensed under the Apache License, Version 2.0 (the "License");
7 you may not use this file except in compliance with the License.
8 You may obtain a copy of the License at
9 
10     http://www.apache.org/licenses/LICENSE-2.0
11 
12 Unless required by applicable law or agreed to in writing, software
13 distributed under the License is distributed on an "AS IS" BASIS,
14 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 See the License for the specific language governing permissions and
16 limitations under the License.
17 ==============================================================================*/
18 
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/port.h"
26 
27 namespace tensorflow {
28 
29 // PendingCounts is an internal helper class to keep track of pending and
30 // dead counts for nodes, for use in the ExecutorState module.  It
31 // holds a map from Handles to various counts for that handle.  This
32 // information is needed per frame iteration. The amount of memory
33 // needed for an iteration is the same across all executions of the
34 // iteration. The memory amount and handles are precomputed at startup
35 // using a Layout object.
36 //
//    PendingCounts::Layout layout;
//    std::vector<PendingCounts::Handle> h(C);
//    for (int id = 0; id < C; id++) {
//      h[id] = layout.CreateHandle(max_pending[id], max_dead[id]);
//    }
42 //
43 // When we actually want to start an iteration we first create a
44 // PendingCounts object and then index into it using the precomputed
45 // handles:
//
47 //    PendingCounts counts(layout);
48 //    ...
49 //    counts.decrement_pending(h[id], 1);
50 class PendingCounts {
51  public:
52   // The state machine for a node's execution.
53   enum NodeState {
54     // The pending count for the node > 0.
55     PENDING_NOTREADY,
56     // The pending count for the node == 0, but the node has not
57     // started executing.
58     PENDING_READY,
59     // The node has started executing.
60     STARTED,
61     // The node has finished executing.
62     COMPLETED
63   };
64 
65   // An opaque handle indicating where in the PendingCounts data structure
66   // the appropriate count information can be found.
67   class Handle;
68   // Given a node that needs to represent counts no larger than the
69   // specified "max_pending_count" and "max_dead_count", create a
70   // handle that can be passed to various PendingCounts routines
71   // to retrieve the count data for this node.
72   class Layout {
73    public:
74     Handle CreateHandle(size_t max_pending_count, size_t max_dead_count);
75 
76    private:
77     friend class PendingCounts;
78     int next_offset_ = 0;  // Next byte offset to allocate
79   };
80 
81   // Create a new PendingCounts object that can hold the state of
82   // all the Handles allocated from "final_allocator".
PendingCounts(Layout layout)83   explicit PendingCounts(Layout layout)
84       : num_bytes_(layout.next_offset_), bytes_(new char[num_bytes_]) {
85     if (num_bytes_ >= sizeof(LargeCounts)) {
86       CHECK_EQ(uintptr_t(bytes_) % alignof(LargeCounts), 0);
87     }
88   }
89 
90   // Create a new PendingCounts object with the same layout and counts
91   // as "other".
PendingCounts(const PendingCounts & other)92   explicit PendingCounts(const PendingCounts& other)
93       : num_bytes_(other.num_bytes_), bytes_(new char[num_bytes_]) {
94     if (num_bytes_ >= sizeof(LargeCounts)) {
95       CHECK_EQ(uintptr_t(bytes_) % alignof(LargeCounts), 0);
96     }
97     memcpy(bytes_, other.bytes_, other.num_bytes_);
98   }
99 
~PendingCounts()100   ~PendingCounts() { delete[] bytes_; }
101 
set_initial_count(Handle h,size_t pending_count)102   void set_initial_count(Handle h, size_t pending_count) {
103     if (h.is_large_) {
104       std::atomic<LargeCounts>* c_ptr = Large(h);
105       auto c = c_ptr->load(std::memory_order_relaxed);
106       c.pending = pending_count;
107       c.dead_count = 0;
108       c.has_started = 0;
109       c_ptr->store(c, std::memory_order_relaxed);
110     } else {
111       DCHECK_LE(pending_count, kMaxCountForPackedCounts);
112       std::atomic<PackedCounts>* c_ptr = Packed(h);
113       auto c = c_ptr->load(std::memory_order_relaxed);
114       c.pending = pending_count;
115       c.dead_count = 0;
116       c.has_started = 0;
117       c_ptr->store(c, std::memory_order_relaxed);
118     }
119   }
120 
node_state(Handle h)121   NodeState node_state(Handle h) {
122     if (h.is_large_) {
123       return NodeStateForStruct(Large(h)->load(std::memory_order_relaxed));
124     } else {
125       return NodeStateForStruct(Packed(h)->load(std::memory_order_relaxed));
126     }
127   }
mark_started(Handle h)128   void mark_started(Handle h) {
129     DCHECK_EQ(pending(h), 0);
130     if (h.is_large_) {
131       std::atomic<LargeCounts>* c_ptr = Large(h);
132       auto c = c_ptr->load(std::memory_order_relaxed);
133       DCHECK_EQ(c.has_started, 0);
134       c.has_started = 1;
135       c_ptr->store(c, std::memory_order_relaxed);
136     } else {
137       std::atomic<PackedCounts>* c_ptr = Packed(h);
138       auto c = c_ptr->load(std::memory_order_relaxed);
139       DCHECK_EQ(c.has_started, 0);
140       c.has_started = 1;
141       c_ptr->store(c, std::memory_order_relaxed);
142     }
143   }
mark_completed(Handle h)144   void mark_completed(Handle h) {
145     if (h.is_large_) {
146       std::atomic<LargeCounts>* c_ptr = Large(h);
147       auto c = c_ptr->load(std::memory_order_relaxed);
148       DCHECK_EQ(c.has_started, 1);
149       c.pending = 1;
150       c_ptr->store(c, std::memory_order_relaxed);
151     } else {
152       std::atomic<PackedCounts>* c_ptr = Packed(h);
153       auto c = c_ptr->load(std::memory_order_relaxed);
154       DCHECK_EQ(c.has_started, 1);
155       c.pending = 1;
156       c_ptr->store(c, std::memory_order_relaxed);
157     }
158   }
pending(Handle h)159   int pending(Handle h) {
160     if (h.is_large_) {
161       LargeCounts c = Large(h)->load(std::memory_order_relaxed);
162       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
163         return c.pending;
164       } else {
165         // The pending count encodes the state once the node has
166         // started, so just return 0.
167         return 0;
168       }
169     } else {
170       PackedCounts c = Packed(h)->load(std::memory_order_relaxed);
171       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
172         return c.pending;
173       } else {
174         // The pending count encodes the state once the node has
175         // started, so just return 0.
176         return 0;
177       }
178     }
179   }
decrement_pending(Handle h,int v)180   int decrement_pending(Handle h, int v) {
181     DCHECK_GE(pending(h), v);
182     if (h.is_large_) {
183       std::atomic<LargeCounts>* c_ptr = Large(h);
184       auto c = c_ptr->load(std::memory_order_relaxed);
185       c.pending -= v;
186       c_ptr->store(c, std::memory_order_relaxed);
187       return c.pending;
188     } else {
189       std::atomic<PackedCounts>* c_ptr = Packed(h);
190       auto c = c_ptr->load(std::memory_order_relaxed);
191       c.pending -= v;
192       c_ptr->store(c, std::memory_order_relaxed);
193       return c.pending;
194     }
195   }
196   // Mark a merge node as live
197   // REQUIRES: Node corresponding to "h" is a merge node
mark_live(Handle h)198   void mark_live(Handle h) {
199     if (h.is_large_) {
200       std::atomic<LargeCounts>* c_ptr = Large(h);
201       auto c = c_ptr->load(std::memory_order_relaxed);
202       // Only do anything if the node hasn't already started executing.
203       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
204         c.pending &= ~static_cast<int>(0x1);
205         c_ptr->store(c, std::memory_order_relaxed);
206       }
207     } else {
208       std::atomic<PackedCounts>* c_ptr = Packed(h);
209       auto c = c_ptr->load(std::memory_order_relaxed);
210       // Only do anything if the node hasn't already started executing.
211       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
212         static_assert(7 == kMaxCountForPackedCounts,
213                       "Live flag incorrect for max packed count");
214         c.pending &= 0x6;
215         c_ptr->store(c, std::memory_order_relaxed);
216       }
217     }
218   }
219 
dead_count(Handle h)220   int dead_count(Handle h) {
221     int r = h.is_large_ ? Large(h)->load(std::memory_order_relaxed).dead_count
222                         : Packed(h)->load(std::memory_order_relaxed).dead_count;
223     return r;
224   }
increment_dead_count(Handle h)225   void increment_dead_count(Handle h) {
226     if (h.is_large_) {
227       std::atomic<LargeCounts>* c_ptr = Large(h);
228       auto c = c_ptr->load(std::memory_order_relaxed);
229       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
230         c.dead_count++;
231         c_ptr->store(c, std::memory_order_relaxed);
232       }
233     } else {
234       std::atomic<PackedCounts>* c_ptr = Packed(h);
235       auto c = c_ptr->load(std::memory_order_relaxed);
236       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
237         DCHECK_LT(c.dead_count, kMaxCountForPackedCounts);
238         c.dead_count++;
239         c_ptr->store(c, std::memory_order_relaxed);
240       }
241     }
242   }
243 
244   struct AdjustResult {
245     bool any_dead;
246     bool any_pending;
247 
AdjustResultAdjustResult248     AdjustResult(bool any_dead, bool any_pending)
249         : any_dead(any_dead), any_pending(any_pending) {}
250   };
251 
252   // A streamlined routine that does several pieces of bookkeeping at
253   // once.  Equivalent to:
254   //    if (increment_dead) increment_dead_count(h);
255   //    decrement_pending(h, 1);
256   //    return {dead_count(h) > 0, pending(h) > 0};
adjust_for_activation(Handle h,bool increment_dead)257   AdjustResult adjust_for_activation(Handle h, bool increment_dead) {
258     DCHECK_GE(pending(h), 1);
259     if (h.is_large_) {
260       return adjust_for_activation_shared(Large(h), increment_dead);
261     } else {
262       return adjust_for_activation_shared(Packed(h), increment_dead);
263     }
264   }
265 
266   // The same as the above, but performs the operation atomically. This
267   // is thread-safe to run concurrently with other threads.
adjust_for_activation_atomic(Handle h,bool increment_dead)268   AdjustResult adjust_for_activation_atomic(Handle h, bool increment_dead) {
269     DCHECK_GE(pending(h), 1);
270     if (h.is_large_) {
271       return adjust_for_activation_shared_atomic(Large(h), increment_dead);
272     } else {
273       return adjust_for_activation_shared_atomic(Packed(h), increment_dead);
274     }
275   }
276 
277   class Handle {
278    public:
Handle()279     Handle() : byte_offset_(0), is_large_(0) {}
280 
281    private:
282     friend class PendingCounts;
283     int byte_offset_ : 31;  // Byte offset of the rep in PendingCounts object
284     bool is_large_ : 1;  // If true, rep is LargeCounts; otherwise PackedCounts
285   };
286 
287  private:
288   template <typename T>
adjust_for_activation_shared(std::atomic<T> * c,bool increment_dead)289   inline AdjustResult adjust_for_activation_shared(std::atomic<T>* c,
290                                                    bool increment_dead) {
291     T val = c->load(std::memory_order_relaxed);
292     if (increment_dead && PENDING_NOTREADY == NodeStateForStruct(val)) {
293       val.dead_count++;
294     }
295     val.pending--;
296     c->store(val, std::memory_order_relaxed);
297     return AdjustResult(val.dead_count, val.pending);
298   }
299 
300   template <typename T>
adjust_for_activation_shared_atomic(std::atomic<T> * c,bool increment_dead)301   inline AdjustResult adjust_for_activation_shared_atomic(std::atomic<T>* c,
302                                                           bool increment_dead) {
303     T old_val = c->load(std::memory_order_relaxed);
304     while (true) {
305       T new_val = old_val;
306       if (increment_dead && PENDING_NOTREADY == NodeStateForStruct(new_val)) {
307         new_val.dead_count++;
308       }
309       new_val.pending--;
310       AdjustResult ret(new_val.dead_count, new_val.pending);
311       if (TF_PREDICT_TRUE(c->compare_exchange_weak(old_val, new_val)))
312         return ret;
313     }
314   }
315 
316   // We keep track of the pending count and dead input count for each
317   // graph node.  The representation used here is designed to be cache
318   // efficient for graphs with large numbers of nodes, where most
319   // nodes have relatively small maximum pending counts (e.g. for one
320   // LSTM model, 99% of 5000+ nodes had in-degrees of 3 or less).  We
321   // use one byte to hold both the pending and dead count for a node
322   // where these together can fit in one byte, and we use a hash table
323   // to handle the rare node ids that need larger counts than this.
324   // Each frame in this subgraph has its own PendingCounts.
325 
326   // We use 3 bits each for dead_count and pending.
327   static constexpr int kMaxCountForPackedCounts = 7;
328 
329   // Most counts are small, so we pack a pending count and a dead
330   // count into 3 bits each, use 1 bit to indicate that the node has
331   // started computing.
332   struct PackedCounts {
333     uint8 pending : 3;
334     uint8 dead_count : 3;
335     uint8 has_started : 1;
336   };
337 
338   // NOTE: alignas(8) is critical to implement efficient atomic<LargeCounts>
339   // on MSVC.
340   struct alignas(8) LargeCounts {
341     uint32 pending;
342     uint32 dead_count : 31;
343     // NOTE(tlipcon): MSVC won't pack this struct into 8 bytes unless
344     // all of the member types are uint32.
345     uint32 has_started : 1;
346   };
347 
348   template <typename T>
NodeStateForStruct(const T & c)349   NodeState NodeStateForStruct(const T& c) const {
350     if (c.has_started) {
351       return (c.pending == 0) ? STARTED : COMPLETED;
352     } else {
353       return (c.pending == 0) ? PENDING_READY : PENDING_NOTREADY;
354     }
355   }
Large(Handle h)356   inline std::atomic<LargeCounts>* Large(Handle h) {
357     DCHECK(h.is_large_);
358     DCHECK_LE(h.byte_offset_ + sizeof(std::atomic<LargeCounts>), num_bytes_);
359     DCHECK_EQ(h.byte_offset_ % alignof(std::atomic<LargeCounts>), 0);
360     return reinterpret_cast<std::atomic<LargeCounts>*>(bytes_ + h.byte_offset_);
361   }
Packed(Handle h)362   inline std::atomic<PackedCounts>* Packed(Handle h) {
363     DCHECK(!h.is_large_);
364     DCHECK_LE(h.byte_offset_ + sizeof(PackedCounts), num_bytes_);
365     return reinterpret_cast<std::atomic<PackedCounts>*>(bytes_ +
366                                                         h.byte_offset_);
367   }
368 
369   const int num_bytes_;  // Just for bounds checking in debug mode
370   char* bytes_;          // Array of num_bytes_ bytes
371 
372   void operator=(const PendingCounts&) = delete;
373 };
374 
CreateHandle(size_t max_pending_count,size_t max_dead_count)375 inline PendingCounts::Handle PendingCounts::Layout::CreateHandle(
376     size_t max_pending_count, size_t max_dead_count) {
377   Handle result;
378   if ((max_pending_count > kMaxCountForPackedCounts) ||
379       (max_dead_count > kMaxCountForPackedCounts)) {
380     constexpr int B = sizeof(std::atomic<LargeCounts>);
381     // Round byte offset to proper alignment
382     static_assert(
383         sizeof(std::atomic<LargeCounts>) >= alignof(std::atomic<LargeCounts>),
384         "std::atomic<LargeCounts> must be packed");
385     int64 offset = ((static_cast<int64>(next_offset_) + B - 1) / B) * B;
386     result.byte_offset_ = offset;
387     result.is_large_ = true;
388     next_offset_ = result.byte_offset_ + B;
389   } else {
390     result.byte_offset_ = next_offset_;
391     result.is_large_ = false;
392     static_assert(sizeof(std::atomic<PackedCounts>) == 1,
393                   "std::atomic<PackedCounts> should be a single byte");
394     next_offset_ += sizeof(std::atomic<PackedCounts>);
395   }
396   return result;
397 }
398 
399 }  // end namespace tensorflow
400 
401 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
402