1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file wraps hipsparse API calls with a DSO loader so that we don't need
// to link explicitly against libhipsparse (PLATFORM_GOOGLE builds call the
// hipsparse functions directly instead). All TF hipsparse API usage should be
// routed through this wrapper.
19 
20 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
21 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
22 
23 #include "rocm/include/hipsparse/hipsparse.h"
24 #include "tensorflow/stream_executor/lib/env.h"
25 #include "tensorflow/stream_executor/platform/dso_loader.h"
26 #include "tensorflow/stream_executor/platform/port.h"
27 
28 namespace tensorflow {
29 namespace wrap {
30 
// HIPSPARSE_API_WRAPPER(api_name) declares, in namespace tensorflow::wrap, a
// callable object named `api_name`. The object forwards its arguments to the
// hipSPARSE function `::api_name` and returns that call's hipsparseStatus_t.
//
// Why `static`: this is a header. A namespace-scope object with external
// linkage (`} api_name;`) or an out-of-class definition of a static data member
// would be defined once per including translation unit. That violates the ODR
// and breaks the link as soon as two .cc files include this header. With
// `static`, each translation unit gets its own internal-linkage shim object.
#ifdef PLATFORM_GOOGLE

// Direct-call variant: the shim calls `::api_name` itself.
#define HIPSPARSE_API_WRAPPER(api_name)             \
  static struct WrapperShim__##api_name {           \
    template <typename... Args>                     \
    hipsparseStatus_t operator()(Args... args) {    \
      return ::api_name(args...);                   \
    }                                               \
  } api_name;

#else

// Dynamic-loading variant. On the first call, DynLoad() resolves the symbol
// named kName from the handle returned by
// CachedDsoLoader::GetHipsparseDsoHandle(). It then keeps the resulting
// function pointer in a function-local static, so later calls reuse it.
// A missing symbol fails the CHECK.
// NOTE(review): ValueOrDie() is presumably fatal when no DSO handle is
// available; confirm against the StatusOr implementation.
#define HIPSPARSE_API_WRAPPER(api_name)                                        \
  static struct DynLoadShim__##api_name {                                      \
    static constexpr const char* kName = #api_name;                            \
    using FuncPtrT = std::add_pointer<decltype(::api_name)>::type;             \
    static void* GetDsoHandle() {                                              \
      auto s =                                                                 \
          stream_executor::internal::CachedDsoLoader::GetHipsparseDsoHandle(); \
      return s.ValueOrDie();                                                   \
    }                                                                          \
    static FuncPtrT LoadOrDie() {                                              \
      void* f = nullptr;                                                       \
      auto s =                                                                 \
          Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), kName, &f);     \
      CHECK(s.ok()) << "could not find " << kName                              \
                    << " in hipsparse DSO; dlerror: " << s.error_message();    \
      return reinterpret_cast<FuncPtrT>(f);                                    \
    }                                                                          \
    static FuncPtrT DynLoad() {                                                \
      static FuncPtrT f = LoadOrDie();                                         \
      return f;                                                                \
    }                                                                          \
    template <typename... Args>                                                \
    hipsparseStatus_t operator()(Args... args) {                               \
      return DynLoad()(args...);                                               \
    }                                                                          \
  } api_name;

#endif
73 
74 // clang-format off
75 #define FOREACH_HIPSPARSE_API(__macro)		\
76   __macro(hipsparseCreate)			\
77   __macro(hipsparseCreateMatDescr)		\
78   __macro(hipsparseDcsr2csc)			\
79   __macro(hipsparseDcsrgemm)			\
80   __macro(hipsparseDcsrmm2)			\
81   __macro(hipsparseDcsrmv)			\
82   __macro(hipsparseDestroy)			\
83   __macro(hipsparseDestroyMatDescr)		\
84   __macro(hipsparseScsr2csc)			\
85   __macro(hipsparseScsrgemm)			\
86   __macro(hipsparseScsrmm2)			\
87   __macro(hipsparseScsrmv)			\
88   __macro(hipsparseSetStream)			\
89   __macro(hipsparseSetMatIndexBase)		\
90   __macro(hipsparseSetMatType)			\
91   __macro(hipsparseXcoo2csr)			\
92   __macro(hipsparseXcsr2coo)			\
93   __macro(hipsparseXcsrgemmNnz)
94 
95 // clang-format on
96 
97 FOREACH_HIPSPARSE_API(HIPSPARSE_API_WRAPPER)
98 
99 #undef FOREACH_HIPSPARSE_API
100 #undef HIPSPARSE_API_WRAPPER
101 
102 }  // namespace wrap
103 }  // namespace tensorflow
104 
105 #endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
106