1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/platform/default/gpu/cupti_wrapper.h"
17 
18 #if GOOGLE_CUDA
19 
20 #include <string>
21 
22 #include "tensorflow/core/platform/env.h"
23 #include "tensorflow/core/platform/stream_executor.h"
24 
25 namespace perftools {
26 namespace gputools {
27 namespace profiler {
28 
29 namespace dynload {
30 
31 #define LIBCUPTI_WRAP(__name)                                                \
32   struct DynLoadShim__##__name {                                             \
33     static const char* kName;                                                \
34     using FuncPointerT = std::add_pointer<decltype(::__name)>::type;         \
35     template <typename... Args>                                              \
36     CUptiResult operator()(Args... args) {                                   \
37       static auto fn = []() -> FuncPointerT {                                \
38         auto handle_or =                                                     \
39             stream_executor::internal::CachedDsoLoader::GetCuptiDsoHandle(); \
40         if (!handle_or.ok()) return nullptr;                                 \
41         void* symbol;                                                        \
42         stream_executor::port::Env::Default()                                \
43             ->GetSymbolFromLibrary(handle_or.ValueOrDie(), kName, &symbol)   \
44             .IgnoreError();                                                  \
45         return reinterpret_cast<FuncPointerT>(symbol);                       \
46       }();                                                                   \
47       if (fn == nullptr) return CUPTI_ERROR_UNKNOWN;                         \
48       return fn(args...);                                                    \
49     }                                                                        \
50   } __name;                                                                  \
51   const char* DynLoadShim__##__name::kName = #__name;
52 
53 LIBCUPTI_WRAP(cuptiActivityDisable);
54 LIBCUPTI_WRAP(cuptiActivityEnable);
55 LIBCUPTI_WRAP(cuptiActivityFlushAll);
56 LIBCUPTI_WRAP(cuptiActivityGetNextRecord);
57 LIBCUPTI_WRAP(cuptiActivityGetNumDroppedRecords);
58 LIBCUPTI_WRAP(cuptiActivityRegisterCallbacks);
59 LIBCUPTI_WRAP(cuptiGetTimestamp);
60 LIBCUPTI_WRAP(cuptiEnableCallback);
61 LIBCUPTI_WRAP(cuptiEnableDomain);
62 LIBCUPTI_WRAP(cuptiSubscribe);
63 LIBCUPTI_WRAP(cuptiUnsubscribe);
64 LIBCUPTI_WRAP(cuptiGetResultString);
65 
66 }  // namespace dynload
67 
ActivityDisable(CUpti_ActivityKind kind)68 CUptiResult CuptiWrapper::ActivityDisable(CUpti_ActivityKind kind) {
69   return dynload::cuptiActivityDisable(kind);
70 }
71 
ActivityEnable(CUpti_ActivityKind kind)72 CUptiResult CuptiWrapper::ActivityEnable(CUpti_ActivityKind kind) {
73   return dynload::cuptiActivityEnable(kind);
74 }
75 
ActivityFlushAll(uint32_t flag)76 CUptiResult CuptiWrapper::ActivityFlushAll(uint32_t flag) {
77   return dynload::cuptiActivityFlushAll(flag);
78 }
79 
ActivityGetNextRecord(uint8_t * buffer,size_t valid_buffer_size_bytes,CUpti_Activity ** record)80 CUptiResult CuptiWrapper::ActivityGetNextRecord(uint8_t* buffer,
81                                                 size_t valid_buffer_size_bytes,
82                                                 CUpti_Activity** record) {
83   return dynload::cuptiActivityGetNextRecord(buffer, valid_buffer_size_bytes,
84                                              record);
85 }
86 
ActivityGetNumDroppedRecords(CUcontext context,uint32_t stream_id,size_t * dropped)87 CUptiResult CuptiWrapper::ActivityGetNumDroppedRecords(CUcontext context,
88                                                        uint32_t stream_id,
89                                                        size_t* dropped) {
90   return dynload::cuptiActivityGetNumDroppedRecords(context, stream_id,
91                                                     dropped);
92 }
93 
ActivityRegisterCallbacks(CUpti_BuffersCallbackRequestFunc func_buffer_requested,CUpti_BuffersCallbackCompleteFunc func_buffer_completed)94 CUptiResult CuptiWrapper::ActivityRegisterCallbacks(
95     CUpti_BuffersCallbackRequestFunc func_buffer_requested,
96     CUpti_BuffersCallbackCompleteFunc func_buffer_completed) {
97   return dynload::cuptiActivityRegisterCallbacks(func_buffer_requested,
98                                                  func_buffer_completed);
99 }
100 
GetTimestamp(uint64_t * timestamp)101 CUptiResult CuptiWrapper::GetTimestamp(uint64_t* timestamp) {
102   return dynload::cuptiGetTimestamp(timestamp);
103 }
104 
EnableCallback(uint32_t enable,CUpti_SubscriberHandle subscriber,CUpti_CallbackDomain domain,CUpti_CallbackId cbid)105 CUptiResult CuptiWrapper::EnableCallback(uint32_t enable,
106                                          CUpti_SubscriberHandle subscriber,
107                                          CUpti_CallbackDomain domain,
108                                          CUpti_CallbackId cbid) {
109   return dynload::cuptiEnableCallback(enable, subscriber, domain, cbid);
110 }
111 
EnableDomain(uint32_t enable,CUpti_SubscriberHandle subscriber,CUpti_CallbackDomain domain)112 CUptiResult CuptiWrapper::EnableDomain(uint32_t enable,
113                                        CUpti_SubscriberHandle subscriber,
114                                        CUpti_CallbackDomain domain) {
115   return dynload::cuptiEnableDomain(enable, subscriber, domain);
116 }
117 
Subscribe(CUpti_SubscriberHandle * subscriber,CUpti_CallbackFunc callback,void * userdata)118 CUptiResult CuptiWrapper::Subscribe(CUpti_SubscriberHandle* subscriber,
119                                     CUpti_CallbackFunc callback,
120                                     void* userdata) {
121   return dynload::cuptiSubscribe(subscriber, callback, userdata);
122 }
123 
Unsubscribe(CUpti_SubscriberHandle subscriber)124 CUptiResult CuptiWrapper::Unsubscribe(CUpti_SubscriberHandle subscriber) {
125   return dynload::cuptiUnsubscribe(subscriber);
126 }
127 
GetResultString(CUptiResult result,const char ** str)128 CUptiResult CuptiWrapper::GetResultString(CUptiResult result,
129                                           const char** str) {
130   return dynload::cuptiGetResultString(result, str);
131 }
132 
133 }  // namespace profiler
134 }  // namespace gputools
135 }  // namespace perftools
136 
137 #endif  // GOOGLE_CUDA
138