1 //
2 // Copyright (C) Microsoft Corporation
3 // All rights reserved.
4 // Modified for native C++ WRL support by Gregory Morse
5 //
6 // Code in Details namespace is for internal usage within the library code
7 //
8 
9 #ifndef _PLATFORM_AGILE_H_
10 #define _PLATFORM_AGILE_H_
11 
12 #ifdef _MSC_VER
13 #pragma once
14 #endif  // _MSC_VER
15 
16 #include <algorithm>
17 #include <wrl\client.h>
18 
// Forward declaration of the Agile smart-pointer template. The definition
// (including the default for TIsNotAgile) appears later in this header.
template <typename T, bool TIsNotAgile> class Agile;

// Compile-time trait: UnwrapAgile<X>::_IsAgile is true when X is an
// instantiation of Agile<T, false> or Agile<T, true>, and false for any
// other type.
template <typename T>
struct UnwrapAgile
{
    static const bool _IsAgile = false;
};
template <typename T>
struct UnwrapAgile<Agile<T, false>>
{
    static const bool _IsAgile = true;
};
template <typename T>
struct UnwrapAgile<Agile<T, true>>
{
    static const bool _IsAgile = true;
};

// Convenience wrapper around UnwrapAgile<...>::_IsAgile.
// The macro is variadic so that type arguments containing top-level commas
// (for example IS_AGILE(Agile<Foo, true>)) are forwarded intact instead of
// being split into several macro arguments, and the expansion is
// parenthesized so it composes safely inside larger expressions.
#define IS_AGILE(...) (UnwrapAgile<__VA_ARGS__>::_IsAgile)
38 
// True when T is HSTRING__, or derives from Microsoft::WRL::FtmBase or from
// IAgileObject.
#define __is_winrt_agile(T) \
    (std::is_same<T, HSTRING__>::value || \
     std::is_base_of<Microsoft::WRL::FtmBase, T>::value || \
     std::is_base_of<IAgileObject, T>::value)

// True when T derives from IUnknown or from IInspectable.
#define __is_win_interface(T) \
    (std::is_base_of<IUnknown, T>::value || \
     std::is_base_of<IInspectable, T>::value)

// True when T is HSTRING__, or derives from
// Microsoft::WRL::Details::RuntimeClassBase (the base of WRL RuntimeClass).
#define __is_win_class(T) \
    (std::is_same<T, HSTRING__>::value || \
     std::is_base_of<Microsoft::WRL::Details::RuntimeClassBase, T>::value)
44 
45     namespace Details
46     {
47         IUnknown* __stdcall GetObjectContext();
48         HRESULT __stdcall GetProxyImpl(IUnknown*, REFIID, IUnknown*, IUnknown**);
49         HRESULT __stdcall ReleaseInContextImpl(IUnknown*, IUnknown*);
50 
51         template <typename T>
52 #if _MSC_VER >= 1800
53         __declspec(no_refcount) inline HRESULT GetProxy(T *ObjectIn, IUnknown *ContextCallBack, T **Proxy)
54 #else
55         inline HRESULT GetProxy(T *ObjectIn, IUnknown *ContextCallBack, T **Proxy)
56 #endif
57         {
58 #if _MSC_VER >= 1800
59             return GetProxyImpl(*reinterpret_cast<IUnknown**>(&ObjectIn), __uuidof(T*), ContextCallBack, reinterpret_cast<IUnknown**>(Proxy));
60 #else
61             return GetProxyImpl(*reinterpret_cast<IUnknown**>(&const_cast<T*>(ObjectIn)), __uuidof(T*), ContextCallBack, reinterpret_cast<IUnknown**>(Proxy));
62 #endif
63         }
64 
65         template <typename T>
66         inline HRESULT ReleaseInContext(T *ObjectIn, IUnknown *ContextCallBack)
67         {
68             return ReleaseInContextImpl(ObjectIn, ContextCallBack);
69         }
70 
71         template <typename T>
72         class AgileHelper
73         {
74             __abi_IUnknown* _p;
75             bool _release;
76         public:
77             AgileHelper(__abi_IUnknown* p, bool release = true) : _p(p), _release(release)
78             {
79             }
80             AgileHelper(AgileHelper&& other) : _p(other._p), _release(other._release)
81             {
82                 _other._p = nullptr;
83                 _other._release = true;
84             }
85             AgileHelper operator=(AgileHelper&& other)
86             {
87                 _p = other._p;
88                 _release = other._release;
89                 _other._p = nullptr;
90                 _other._release = true;
91                 return *this;
92             }
93 
94             ~AgileHelper()
95             {
96                 if (_release && _p)
97                 {
98                     _p->__abi_Release();
99                 }
100             }
101 
102             __declspec(no_refcount) __declspec(no_release_return)
103                 T* operator->()
104             {
105                     return reinterpret_cast<T*>(_p);
106             }
107 
108             __declspec(no_refcount) __declspec(no_release_return)
109                 operator T * ()
110             {
111                     return reinterpret_cast<T*>(_p);
112             }
113         private:
114             AgileHelper(const AgileHelper&);
115             AgileHelper operator=(const AgileHelper&);
116         };
// Strips one level of pointer from T: __remove_hat<U*>::type is U, and
// __remove_hat<U>::type is U for any non-pointer U.
template <typename T>
struct __remove_hat
{
    typedef T type;
};
template <typename T>
struct __remove_hat<T*>
{
    typedef T type;
};

// Type computations used by Agile<T>:
//   type            - T with one pointer level removed (see __remove_hat).
//   agileMemberType - pointer to that type; this is the pointer type Agile
//                     stores and hands out.
template <typename T>
struct AgileTypeHelper
{
    // "typedef typename" is the standard spelling; the previous
    // "typename typedef" order is accepted only as a compiler extension.
    typedef typename __remove_hat<T>::type type;
    typedef typename __remove_hat<T>::type* agileMemberType;
};
133     } // namespace Details
134 
135 #pragma warning(push)
136 #pragma warning(disable: 4451) // Usage of ref class inside this context can lead to invalid marshaling of object across contexts
137 
138     template <
139         typename T,
140         bool TIsNotAgile = (__is_win_class(typename Details::AgileTypeHelper<T>::type) && !__is_winrt_agile(typename Details::AgileTypeHelper<T>::type)) ||
141         __is_win_interface(typename Details::AgileTypeHelper<T>::type)
142     >
143     class Agile
144     {
145         static_assert(__is_win_class(typename Details::AgileTypeHelper<T>::type) || __is_win_interface(typename Details::AgileTypeHelper<T>::type), "Agile can only be used with ref class or interface class types");
146         typename typedef Details::AgileTypeHelper<T>::agileMemberType TypeT;
147         TypeT _object;
148         ::Microsoft::WRL::ComPtr<IUnknown> _contextCallback;
149         ULONG_PTR _contextToken;
150 
151 #if _MSC_VER >= 1800
152         enum class AgileState
153         {
154             NonAgilePointer = 0,
155             AgilePointer = 1,
156             Unknown = 2
157         };
158         AgileState _agileState;
159 #endif
160 
161         void CaptureContext()
162         {
163             _contextCallback = Details::GetObjectContext();
164             __abi_ThrowIfFailed(CoGetContextToken(&_contextToken));
165         }
166 
167         void SetObject(TypeT object)
168         {
169             // Capture context before setting the pointer
170             // If context capture fails then nothing to cleanup
171             Release();
172             if (object != nullptr)
173             {
174                 ::Microsoft::WRL::ComPtr<IAgileObject> checkIfAgile;
175                 HRESULT hr = reinterpret_cast<IUnknown*>(object)->QueryInterface(__uuidof(IAgileObject), &checkIfAgile);
176                 // Don't Capture context if object is agile
177                 if (hr != S_OK)
178                 {
179 #if _MSC_VER >= 1800
180                     _agileState = AgileState::NonAgilePointer;
181 #endif
182                     CaptureContext();
183                 }
184 #if _MSC_VER >= 1800
185                 else
186                 {
187                     _agileState = AgileState::AgilePointer;
188                 }
189 #endif
190             }
191             _object = object;
192         }
193 
194     public:
195         Agile() throw() : _object(nullptr), _contextToken(0)
196 #if _MSC_VER >= 1800
197             , _agileState(AgileState::Unknown)
198 #endif
199         {
200         }
201 
202         Agile(nullptr_t) throw() : _object(nullptr), _contextToken(0)
203 #if _MSC_VER >= 1800
204             , _agileState(AgileState::Unknown)
205 #endif
206         {
207         }
208 
209         explicit Agile(TypeT object) throw() : _object(nullptr), _contextToken(0)
210 #if _MSC_VER >= 1800
211             , _agileState(AgileState::Unknown)
212 #endif
213         {
214             // Assumes that the source object is from the current context
215             SetObject(object);
216         }
217 
218         Agile(const Agile& object) throw() : _object(nullptr), _contextToken(0)
219 #if _MSC_VER >= 1800
220             , _agileState(AgileState::Unknown)
221 #endif
222         {
223             // Get returns pointer valid for current context
224             SetObject(object.Get());
225         }
226 
227         Agile(Agile&& object) throw() : _object(nullptr), _contextToken(0)
228 #if _MSC_VER >= 1800
229             , _agileState(AgileState::Unknown)
230 #endif
231         {
232             // Assumes that the source object is from the current context
233             Swap(object);
234         }
235 
236         ~Agile() throw()
237         {
238             Release();
239         }
240 
241         TypeT Get() const
242         {
243             // Agile object, no proxy required
244 #if _MSC_VER >= 1800
245             if (_agileState == AgileState::AgilePointer || _object == nullptr)
246 #else
247             if (_contextToken == 0 || _contextCallback == nullptr || _object == nullptr)
248 #endif
249             {
250                 return _object;
251             }
252 
253             // Do the check for same context
254             ULONG_PTR currentContextToken;
255             __abi_ThrowIfFailed(CoGetContextToken(&currentContextToken));
256             if (currentContextToken == _contextToken)
257             {
258                 return _object;
259             }
260 
261 #if _MSC_VER >= 1800
262             // Different context and holding on to a non agile object
263             // Do the costly work of getting a proxy
264             TypeT localObject;
265             __abi_ThrowIfFailed(Details::GetProxy(_object, _contextCallback.Get(), &localObject));
266 
267             if (_agileState == AgileState::Unknown)
268 #else
269             // Object is agile if it implements IAgileObject
270             // GetAddressOf captures the context with out knowing the type of object that it will hold
271             if (_object != nullptr)
272 #endif
273             {
274 #if _MSC_VER >= 1800
275                 // Object is agile if it implements IAgileObject
276                 // GetAddressOf captures the context with out knowing the type of object that it will hold
277                 ::Microsoft::WRL::ComPtr<IAgileObject> checkIfAgile;
278                 HRESULT hr = reinterpret_cast<IUnknown*>(localObject)->QueryInterface(__uuidof(IAgileObject), &checkIfAgile);
279 #else
280                 ::Microsoft::WRL::ComPtr<IAgileObject> checkIfAgile;
281                 HRESULT hr = reinterpret_cast<IUnknown*>(_object)->QueryInterface(__uuidof(IAgileObject), &checkIfAgile);
282 #endif
283                 if (hr == S_OK)
284                 {
285                     auto pThis = const_cast<Agile*>(this);
286 #if _MSC_VER >= 1800
287                     pThis->_agileState = AgileState::AgilePointer;
288 #endif
289                     pThis->_contextToken = 0;
290                     pThis->_contextCallback = nullptr;
291                     return _object;
292                 }
293 #if _MSC_VER >= 1800
294                 else
295                 {
296                     auto pThis = const_cast<Agile*>(this);
297                     pThis->_agileState = AgileState::NonAgilePointer;
298                 }
299 #endif
300             }
301 
302 #if _MSC_VER < 1800
303             // Different context and holding on to a non agile object
304             // Do the costly work of getting a proxy
305             TypeT localObject;
306             __abi_ThrowIfFailed(Details::GetProxy(_object, _contextCallback.Get(), &localObject));
307 #endif
308             return localObject;
309         }
310 
311         TypeT* GetAddressOf() throw()
312         {
313             Release();
314             CaptureContext();
315             return &_object;
316         }
317 
318         TypeT* GetAddressOfForInOut() throw()
319         {
320             CaptureContext();
321             return &_object;
322         }
323 
324         TypeT operator->() const throw()
325         {
326             return Get();
327         }
328 
329         Agile& operator=(nullptr_t) throw()
330         {
331             Release();
332             return *this;
333         }
334 
335         Agile& operator=(TypeT object) throw()
336         {
337             Agile(object).Swap(*this);
338             return *this;
339         }
340 
341         Agile& operator=(Agile object) throw()
342         {
343             // parameter is by copy which gets pointer valid for current context
344             object.Swap(*this);
345             return *this;
346         }
347 
348 #if _MSC_VER < 1800
349         Agile& operator=(IUnknown* lp) throw()
350         {
351             // bump ref count
352             ::Microsoft::WRL::ComPtr<IUnknown> spObject(lp);
353 
354             // put it into Platform Object
355             Platform::Object object;
356             *(IUnknown**)(&object) = spObject.Detach();
357 
358             SetObject(object);
359             return *this;
360         }
361 #endif
362 
363         void Swap(Agile& object)
364         {
365             std::swap(_object, object._object);
366             std::swap(_contextCallback, object._contextCallback);
367             std::swap(_contextToken, object._contextToken);
368 #if _MSC_VER >= 1800
369             std::swap(_agileState, object._agileState);
370 #endif
371         }
372 
373         // Release the interface and set to NULL
374         void Release() throw()
375         {
376             if (_object)
377             {
378                 // Cast to IInspectable (no QI)
379                 IUnknown* pObject = *(IUnknown**)(&_object);
380                 // Set * to null without release
381                 *(IUnknown**)(&_object) = nullptr;
382 
383                 ULONG_PTR currentContextToken;
384                 __abi_ThrowIfFailed(CoGetContextToken(&currentContextToken));
385                 if (_contextToken == 0 || _contextCallback == nullptr || _contextToken == currentContextToken)
386                 {
387                     pObject->Release();
388                 }
389                 else
390                 {
391                     Details::ReleaseInContext(pObject, _contextCallback.Get());
392                 }
393                 _contextCallback = nullptr;
394                 _contextToken = 0;
395 #if _MSC_VER >= 1800
396                 _agileState = AgileState::Unknown;
397 #endif
398             }
399         }
400 
401         bool operator==(nullptr_t) const throw()
402         {
403             return _object == nullptr;
404         }
405 
406         bool operator==(const Agile& other) const throw()
407         {
408             return _object == other._object && _contextToken == other._contextToken;
409         }
410 
411         bool operator<(const Agile& other) const throw()
412         {
413             if (reinterpret_cast<void*>(_object) < reinterpret_cast<void*>(other._object))
414             {
415                 return true;
416             }
417 
418             return _object == other._object && _contextToken < other._contextToken;
419         }
420     };
421 
422     template <typename T>
423     class Agile<T, false>
424     {
425         static_assert(__is_win_class(typename Details::AgileTypeHelper<T>::type) || __is_win_interface(typename Details::AgileTypeHelper<T>::type), "Agile can only be used with ref class or interface class types");
426         typename typedef Details::AgileTypeHelper<T>::agileMemberType TypeT;
427         TypeT _object;
428 
429     public:
430         Agile() throw() : _object(nullptr)
431         {
432         }
433 
434         Agile(nullptr_t) throw() : _object(nullptr)
435         {
436         }
437 
438         explicit Agile(TypeT object) throw() : _object(object)
439         {
440         }
441 
442         Agile(const Agile& object) throw() : _object(object._object)
443         {
444         }
445 
446         Agile(Agile&& object) throw() : _object(nullptr)
447         {
448             Swap(object);
449         }
450 
451         ~Agile() throw()
452         {
453             Release();
454         }
455 
456         TypeT Get() const
457         {
458             return _object;
459         }
460 
461         TypeT* GetAddressOf() throw()
462         {
463             Release();
464             return &_object;
465         }
466 
467         TypeT* GetAddressOfForInOut() throw()
468         {
469             return &_object;
470         }
471 
472         TypeT operator->() const throw()
473         {
474             return Get();
475         }
476 
477         Agile& operator=(nullptr_t) throw()
478         {
479             Release();
480             return *this;
481         }
482 
483         Agile& operator=(TypeT object) throw()
484         {
485             if (_object != object)
486             {
487                 _object = object;
488             }
489             return *this;
490         }
491 
492         Agile& operator=(Agile object) throw()
493         {
494             object.Swap(*this);
495             return *this;
496         }
497 
498 #if _MSC_VER < 1800
499         Agile& operator=(IUnknown* lp) throw()
500         {
501             Release();
502             // bump ref count
503             ::Microsoft::WRL::ComPtr<IUnknown> spObject(lp);
504 
505             // put it into Platform Object
506             Platform::Object object;
507             *(IUnknown**)(&object) = spObject.Detach();
508 
509             _object = object;
510             return *this;
511         }
512 #endif
513 
514         // Release the interface and set to NULL
515         void Release() throw()
516         {
517             _object = nullptr;
518         }
519 
520         void Swap(Agile& object)
521         {
522             std::swap(_object, object._object);
523         }
524 
525         bool operator==(nullptr_t) const throw()
526         {
527             return _object == nullptr;
528         }
529 
530         bool operator==(const Agile& other) const throw()
531         {
532             return _object == other._object;
533         }
534 
535         bool operator<(const Agile& other) const throw()
536         {
537             return reinterpret_cast<void*>(_object) < reinterpret_cast<void*>(other._object);
538         }
539     };
540 
541 #pragma warning(pop)
542 
543     template<class U>
544     bool operator==(nullptr_t, const Agile<U>& a) throw()
545     {
546         return a == nullptr;
547     }
548 
549     template<class U>
550     bool operator!=(const Agile<U>& a, nullptr_t) throw()
551     {
552         return !(a == nullptr);
553     }
554 
555     template<class U>
556     bool operator!=(nullptr_t, const Agile<U>& a) throw()
557     {
558         return !(a == nullptr);
559     }
560 
561     template<class U>
562     bool operator!=(const Agile<U>& a, const Agile<U>& b) throw()
563     {
564         return !(a == b);
565     }
566 
567 
568 #endif // _PLATFORM_AGILE_H_
569