1 //
2 // Copyright 2020 gRPC authors.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #ifndef GRPC_CORE_LIB_GPRPP_DUAL_REF_COUNTED_H
18 #define GRPC_CORE_LIB_GPRPP_DUAL_REF_COUNTED_H
19 
20 #include <grpc/support/port_platform.h>
21 
22 #include <grpc/support/atm.h>
23 #include <grpc/support/log.h>
24 #include <grpc/support/sync.h>
25 
26 #include <atomic>
27 #include <cassert>
28 #include <cinttypes>
29 
30 #include "src/core/lib/gprpp/atomic.h"
31 #include "src/core/lib/gprpp/debug_location.h"
32 #include "src/core/lib/gprpp/orphanable.h"
33 #include "src/core/lib/gprpp/ref_counted_ptr.h"
34 
35 namespace grpc_core {
36 
37 // DualRefCounted is an interface for reference-counted objects with two
38 // classes of refs: strong refs (usually just called "refs") and weak refs.
39 // This supports cases where an object needs to start shutting down when
40 // all external callers are done with it (represented by strong refs) but
41 // cannot be destroyed until all internal callbacks are complete
42 // (represented by weak refs).
43 //
// Each class of refs can be incremented and decremented independently.
// Objects start with 1 strong ref and 0 weak refs at instantiation.
// When the strong refcount reaches 0, the object's Orphan() method is called.
// When both the strong and the weak refcounts have reached 0, the object is
// destroyed.
48 //
// This will be used by CRTP (curiously-recurring template pattern), e.g.:
//   class MyClass : public DualRefCounted<MyClass> { ... };
51 template <typename Child>
52 class DualRefCounted : public Orphanable {
53  public:
54   ~DualRefCounted() override = default;
55 
Ref()56   RefCountedPtr<Child> Ref() GRPC_MUST_USE_RESULT {
57     IncrementRefCount();
58     return RefCountedPtr<Child>(static_cast<Child*>(this));
59   }
60 
Ref(const DebugLocation & location,const char * reason)61   RefCountedPtr<Child> Ref(const DebugLocation& location,
62                            const char* reason) GRPC_MUST_USE_RESULT {
63     IncrementRefCount(location, reason);
64     return RefCountedPtr<Child>(static_cast<Child*>(this));
65   }
66 
Unref()67   void Unref() {
68     // Convert strong ref to weak ref.
69     const uint64_t prev_ref_pair =
70         refs_.FetchAdd(MakeRefPair(-1, 1), MemoryOrder::ACQ_REL);
71     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
72 #ifndef NDEBUG
73     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
74     if (trace_ != nullptr) {
75       gpr_log(GPR_INFO, "%s:%p unref %d -> %d, weak_ref %d -> %d", trace_, this,
76               strong_refs, strong_refs - 1, weak_refs, weak_refs + 1);
77     }
78     GPR_ASSERT(strong_refs > 0);
79 #endif
80     if (GPR_UNLIKELY(strong_refs == 1)) {
81       Orphan();
82     }
83     // Now drop the weak ref.
84     WeakUnref();
85   }
Unref(const DebugLocation & location,const char * reason)86   void Unref(const DebugLocation& location, const char* reason) {
87     const uint64_t prev_ref_pair =
88         refs_.FetchAdd(MakeRefPair(-1, 1), MemoryOrder::ACQ_REL);
89     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
90 #ifndef NDEBUG
91     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
92     if (trace_ != nullptr) {
93       gpr_log(GPR_INFO, "%s:%p %s:%d unref %d -> %d, weak_ref %d -> %d) %s",
94               trace_, this, location.file(), location.line(), strong_refs,
95               strong_refs - 1, weak_refs, weak_refs + 1, reason);
96     }
97     GPR_ASSERT(strong_refs > 0);
98 #else
99     // Avoid unused-parameter warnings for debug-only parameters
100     (void)location;
101     (void)reason;
102 #endif
103     if (GPR_UNLIKELY(strong_refs == 1)) {
104       Orphan();
105     }
106     // Now drop the weak ref.
107     WeakUnref(location, reason);
108   }
109 
RefIfNonZero()110   RefCountedPtr<Child> RefIfNonZero() GRPC_MUST_USE_RESULT {
111     uint64_t prev_ref_pair = refs_.Load(MemoryOrder::ACQUIRE);
112     do {
113       const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
114 #ifndef NDEBUG
115       const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
116       if (trace_ != nullptr) {
117         gpr_log(GPR_INFO, "%s:%p ref_if_non_zero %d -> %d (weak_refs=%d)",
118                 trace_, this, strong_refs, strong_refs + 1, weak_refs);
119       }
120 #endif
121       if (strong_refs == 0) return nullptr;
122     } while (!refs_.CompareExchangeWeak(
123         &prev_ref_pair, prev_ref_pair + MakeRefPair(1, 0), MemoryOrder::ACQ_REL,
124         MemoryOrder::ACQUIRE));
125     return RefCountedPtr<Child>(static_cast<Child*>(this));
126   }
127 
RefIfNonZero(const DebugLocation & location,const char * reason)128   RefCountedPtr<Child> RefIfNonZero(const DebugLocation& location,
129                                     const char* reason) GRPC_MUST_USE_RESULT {
130     uint64_t prev_ref_pair = refs_.Load(MemoryOrder::ACQUIRE);
131     do {
132       const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
133 #ifndef NDEBUG
134       const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
135       if (trace_ != nullptr) {
136         gpr_log(GPR_INFO,
137                 "%s:%p %s:%d ref_if_non_zero %d -> %d (weak_refs=%d) %s",
138                 trace_, this, location.file(), location.line(), strong_refs,
139                 strong_refs + 1, weak_refs, reason);
140       }
141 #else
142       // Avoid unused-parameter warnings for debug-only parameters
143       (void)location;
144       (void)reason;
145 #endif
146       if (strong_refs == 0) return nullptr;
147     } while (!refs_.CompareExchangeWeak(
148         &prev_ref_pair, prev_ref_pair + MakeRefPair(1, 0), MemoryOrder::ACQ_REL,
149         MemoryOrder::ACQUIRE));
150     return RefCountedPtr<Child>(static_cast<Child*>(this));
151   }
152 
WeakRef()153   WeakRefCountedPtr<Child> WeakRef() GRPC_MUST_USE_RESULT {
154     IncrementWeakRefCount();
155     return WeakRefCountedPtr<Child>(static_cast<Child*>(this));
156   }
157 
WeakRef(const DebugLocation & location,const char * reason)158   WeakRefCountedPtr<Child> WeakRef(const DebugLocation& location,
159                                    const char* reason) GRPC_MUST_USE_RESULT {
160     IncrementWeakRefCount(location, reason);
161     return WeakRefCountedPtr<Child>(static_cast<Child*>(this));
162   }
163 
WeakUnref()164   void WeakUnref() {
165 #ifndef NDEBUG
166     // Grab a copy of the trace flag before the atomic change, since we
167     // will no longer be holding a ref afterwards and therefore can't
168     // safely access it, since another thread might free us in the interim.
169     const char* trace = trace_;
170 #endif
171     const uint64_t prev_ref_pair =
172         refs_.FetchSub(MakeRefPair(0, 1), MemoryOrder::ACQ_REL);
173     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
174 #ifndef NDEBUG
175     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
176     if (trace != nullptr) {
177       gpr_log(GPR_INFO, "%s:%p weak_unref %d -> %d (refs=%d)", trace, this,
178               weak_refs, weak_refs - 1, strong_refs);
179     }
180     GPR_ASSERT(weak_refs > 0);
181 #endif
182     if (GPR_UNLIKELY(prev_ref_pair == MakeRefPair(0, 1))) {
183       delete static_cast<Child*>(this);
184     }
185   }
WeakUnref(const DebugLocation & location,const char * reason)186   void WeakUnref(const DebugLocation& location, const char* reason) {
187 #ifndef NDEBUG
188     // Grab a copy of the trace flag before the atomic change, since we
189     // will no longer be holding a ref afterwards and therefore can't
190     // safely access it, since another thread might free us in the interim.
191     const char* trace = trace_;
192 #endif
193     const uint64_t prev_ref_pair =
194         refs_.FetchSub(MakeRefPair(0, 1), MemoryOrder::ACQ_REL);
195     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
196 #ifndef NDEBUG
197     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
198     if (trace_ != nullptr) {
199       gpr_log(GPR_INFO, "%s:%p %s:%d weak_unref %d -> %d (refs=%d) %s", trace,
200               this, location.file(), location.line(), weak_refs, weak_refs - 1,
201               strong_refs, reason);
202     }
203     GPR_ASSERT(weak_refs > 0);
204 #else
205     // Avoid unused-parameter warnings for debug-only parameters
206     (void)location;
207     (void)reason;
208 #endif
209     if (GPR_UNLIKELY(prev_ref_pair == MakeRefPair(0, 1))) {
210       delete static_cast<Child*>(this);
211     }
212   }
213 
214   // Not copyable nor movable.
215   DualRefCounted(const DualRefCounted&) = delete;
216   DualRefCounted& operator=(const DualRefCounted&) = delete;
217 
218  protected:
219   // Note: Tracing is a no-op in non-debug builds.
220   explicit DualRefCounted(
221       const char*
222 #ifndef NDEBUG
223           // Leave unnamed if NDEBUG to avoid unused parameter warning
224           trace
225 #endif
226       = nullptr,
227       int32_t initial_refcount = 1)
228       :
229 #ifndef NDEBUG
trace_(trace)230         trace_(trace),
231 #endif
232         refs_(MakeRefPair(initial_refcount, 0)) {
233   }
234 
235  private:
236   // Allow RefCountedPtr<> to access IncrementRefCount().
237   template <typename T>
238   friend class RefCountedPtr;
239   // Allow WeakRefCountedPtr<> to access IncrementWeakRefCount().
240   template <typename T>
241   friend class WeakRefCountedPtr;
242 
243   // First 32 bits are strong refs, next 32 bits are weak refs.
MakeRefPair(uint32_t strong,uint32_t weak)244   static uint64_t MakeRefPair(uint32_t strong, uint32_t weak) {
245     return (static_cast<uint64_t>(strong) << 32) + static_cast<int64_t>(weak);
246   }
GetStrongRefs(uint64_t ref_pair)247   static uint32_t GetStrongRefs(uint64_t ref_pair) {
248     return static_cast<uint32_t>(ref_pair >> 32);
249   }
GetWeakRefs(uint64_t ref_pair)250   static uint32_t GetWeakRefs(uint64_t ref_pair) {
251     return static_cast<uint32_t>(ref_pair & 0xffffffffu);
252   }
253 
IncrementRefCount()254   void IncrementRefCount() {
255 #ifndef NDEBUG
256     const uint64_t prev_ref_pair =
257         refs_.FetchAdd(MakeRefPair(1, 0), MemoryOrder::RELAXED);
258     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
259     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
260     GPR_ASSERT(strong_refs != 0);
261     if (trace_ != nullptr) {
262       gpr_log(GPR_INFO, "%s:%p ref %d -> %d; (weak_refs=%d)", trace_, this,
263               strong_refs, strong_refs + 1, weak_refs);
264     }
265 #else
266     refs_.FetchAdd(MakeRefPair(1, 0), MemoryOrder::RELAXED);
267 #endif
268   }
IncrementRefCount(const DebugLocation & location,const char * reason)269   void IncrementRefCount(const DebugLocation& location, const char* reason) {
270 #ifndef NDEBUG
271     const uint64_t prev_ref_pair =
272         refs_.FetchAdd(MakeRefPair(1, 0), MemoryOrder::RELAXED);
273     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
274     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
275     GPR_ASSERT(strong_refs != 0);
276     if (trace_ != nullptr) {
277       gpr_log(GPR_INFO, "%s:%p %s:%d ref %d -> %d (weak_refs=%d) %s", trace_,
278               this, location.file(), location.line(), strong_refs,
279               strong_refs + 1, weak_refs, reason);
280     }
281 #else
282     // Use conditionally-important parameters
283     (void)location;
284     (void)reason;
285     refs_.FetchAdd(MakeRefPair(1, 0), MemoryOrder::RELAXED);
286 #endif
287   }
288 
IncrementWeakRefCount()289   void IncrementWeakRefCount() {
290 #ifndef NDEBUG
291     const uint64_t prev_ref_pair =
292         refs_.FetchAdd(MakeRefPair(0, 1), MemoryOrder::RELAXED);
293     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
294     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
295     if (trace_ != nullptr) {
296       gpr_log(GPR_INFO, "%s:%p weak_ref %d -> %d; (refs=%d)", trace_, this,
297               weak_refs, weak_refs + 1, strong_refs);
298     }
299 #else
300     refs_.FetchAdd(MakeRefPair(0, 1), MemoryOrder::RELAXED);
301 #endif
302   }
IncrementWeakRefCount(const DebugLocation & location,const char * reason)303   void IncrementWeakRefCount(const DebugLocation& location,
304                              const char* reason) {
305 #ifndef NDEBUG
306     const uint64_t prev_ref_pair =
307         refs_.FetchAdd(MakeRefPair(0, 1), MemoryOrder::RELAXED);
308     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
309     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
310     if (trace_ != nullptr) {
311       gpr_log(GPR_INFO, "%s:%p %s:%d weak_ref %d -> %d (refs=%d) %s", trace_,
312               this, location.file(), location.line(), weak_refs, weak_refs + 1,
313               strong_refs, reason);
314     }
315 #else
316     // Use conditionally-important parameters
317     (void)location;
318     (void)reason;
319     refs_.FetchAdd(MakeRefPair(0, 1), MemoryOrder::RELAXED);
320 #endif
321   }
322 
323 #ifndef NDEBUG
324   const char* trace_;
325 #endif
326   Atomic<uint64_t> refs_;
327 };
328 
329 }  // namespace grpc_core
330 
331 #endif /* GRPC_CORE_LIB_GPRPP_DUAL_REF_COUNTED_H */
332