1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // ROCM-specific support for BLAS functionality -- this wraps the rocBLAS
17 // library capabilities, and is only included into ROCM implementation code --
18 // it will not introduce rocm headers into other code.
19 
20 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
21 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
22 
23 #include "tensorflow/stream_executor/blas.h"
24 #include "tensorflow/stream_executor/platform/mutex.h"
25 #include "tensorflow/stream_executor/platform/port.h"
26 #include "tensorflow/stream_executor/platform/thread_annotations.h"
27 #include "tensorflow/stream_executor/plugin_registry.h"
28 
29 namespace stream_executor {
30 
31 class Stream;
32 
33 namespace gpu {
34 
35 // Opaque and unique identifier for the rocBLAS plugin.
36 extern const PluginId kRocBlasPlugin;
37 
38 class GpuExecutor;
39 
40 // BLAS plugin for ROCM platform via rocBLAS library.
41 //
42 // This satisfies the platform-agnostic BlasSupport interface.
43 //
44 // Note that the rocBLAS handle that this encapsulates is implicitly tied to the
45 // context (and, as a result, the device) that the parent GpuExecutor is tied
46 // to. This simply happens as an artifact of creating the rocBLAS handle when a
47 // ROCM context is active.
48 //
49 // Thread-safe post-initialization.
50 class ROCMBlas : public blas::BlasSupport {
51  public:
52   explicit ROCMBlas(GpuExecutor *parent);
53 
54   // Allocates a rocBLAS handle.
55   bool Init();
56 
57   // Releases the rocBLAS handle, if present.
58   ~ROCMBlas() override;
59 
60   TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES
61 
62  private:
63   // Tells rocBLAS to enqueue the BLAS operation onto a particular Stream.
64   //
65   // rocBLAS is stateful, and only be associated with one stream (in order to
66   // enqueue dispatch) at a given time. As a result, this generally must be
67   // invoked before calling into rocBLAS.
68   bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_);
69 
70   // A helper function that calls the real rocBLAS function together with error
71   // handling.
72   //
73   // rocblas_func:       rocBLAS function pointer.
74   // rocblas_name:       rocBLAS function name.
75   // stream:             Stream to enqueue the BLAS operation onto.
76   // pointer_mode_host:  Indicate if the pointer to a scalar value is from host
77   //                     (true) or device (false).
78   // err_on_failure:     Whether to print an error if the rocBLAS function
79   // fails. args:               Arguments of rocBLAS function.
80   template <typename FuncT, typename... Args>
81   bool DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
82                           bool pointer_mode_host, bool err_on_failure,
83                           Args... args);
84 
85   // Convenience functions that call DoBlasInternalImpl with different values
86   // for err_on_failure.
87   template <typename FuncT, typename... Args>
DoBlasInternal(FuncT rocblas_func,Stream * stream,bool pointer_mode_host,Args...args)88   bool DoBlasInternal(FuncT rocblas_func, Stream *stream,
89                       bool pointer_mode_host, Args... args) {
90     return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
91                               /*err_on_failure=*/true, args...);
92   }
93   template <typename FuncT, typename... Args>
DoBlasInternalFailureOK(FuncT rocblas_func,Stream * stream,bool pointer_mode_host,Args...args)94   bool DoBlasInternalFailureOK(FuncT rocblas_func, Stream *stream,
95                                bool pointer_mode_host, Args... args) {
96     return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
97                               /*err_on_failure=*/false, args...);
98   }
99 
100   // A helper function to implement DoBlasGemmBatched interfaces for generic
101   // types.
102   template <typename T, typename FuncT>
103   port::Status DoBlasGemmBatchedInternal(
104       FuncT rocblas_func, Stream *stream, blas::Transpose transa,
105       blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
106       const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
107       const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta,
108       const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
109       int batch_count, ScratchAllocator *scratch_allocator);
110 
111   // Helper function for implementing DoBlasGemmWithAlgorithm.
112   //
113   // We take alpha and beta by const reference because T might be Eigen::half,
114   // and we want to avoid pulling in a dependency on Eigen.  When we pass the
115   // references to rocBLAS, we essentially reinterpret_cast to __half, which is
116   // safe because Eigen::half inherits from __half.
117   template <typename InT, typename OutT, typename CompT>
118   bool DoBlasGemmWithAlgorithmImpl(
119       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
120       uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a,
121       int lda, const DeviceMemory<InT> &b, int ldb, const CompT &beta,
122       DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
123       blas::AlgorithmType algorithm,
124       blas::ProfileResult *output_profile_result);
125 
126   // Helper function for implementing DoBlasGemmWithProfiling.
127   template <typename T, typename ParamType>
128   bool DoBlasGemmWithProfilingImpl(
129       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
130       uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
131       int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
132       DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);
133 
134   // Helper function for implementing DoBlasGemvWithProfiling.
135   template <typename T>
136   bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
137                                    uint64 m, uint64 n, const T &alpha,
138                                    const DeviceMemory<T> &a, int lda,
139                                    const DeviceMemory<T> &x, int incx,
140                                    const T &beta, DeviceMemory<T> *y, int incy,
141                                    blas::ProfileResult *output_profile_result);
142 
143   // mutex that guards the rocBLAS handle for this device.
144   mutex mu_;
145 
146   // GpuExecutor which instantiated this ROCMBlas.
147   // Immutable post-initialization.
148   GpuExecutor *parent_;
149 
150   // rocBLAS library handle on the device.
151   rocblas_handle blas_ GUARDED_BY(mu_);
152 
153   SE_DISALLOW_COPY_AND_ASSIGN(ROCMBlas);
154 };
155 
156 }  // namespace gpu
157 }  // namespace stream_executor
158 
159 #endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
160