/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// ROCM-specific support for BLAS functionality -- this wraps the rocBLAS
// library capabilities, and is only included into ROCM implementation code --
// it will not introduce rocm headers into other code.

#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_

#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/plugin_registry.h"

namespace stream_executor {

class Stream;

namespace gpu {

// Opaque and unique identifier for the rocBLAS plugin.
extern const PluginId kRocBlasPlugin;

class GpuExecutor;

// BLAS plugin for the ROCM platform via the rocBLAS library.
//
// This satisfies the platform-agnostic BlasSupport interface.
//
// Note that the rocBLAS handle that this encapsulates is implicitly tied to the
// context (and, as a result, the device) that the parent GpuExecutor is tied
// to. This simply happens as an artifact of creating the rocBLAS handle when a
// ROCM context is active.
//
// Thread-safe post-initialization.
50 class ROCMBlas : public blas::BlasSupport { 51 public: 52 explicit ROCMBlas(GpuExecutor *parent); 53 54 // Allocates a rocBLAS handle. 55 bool Init(); 56 57 // Releases the rocBLAS handle, if present. 58 ~ROCMBlas() override; 59 60 TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES 61 62 private: 63 // Tells rocBLAS to enqueue the BLAS operation onto a particular Stream. 64 // 65 // rocBLAS is stateful, and only be associated with one stream (in order to 66 // enqueue dispatch) at a given time. As a result, this generally must be 67 // invoked before calling into rocBLAS. 68 bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_); 69 70 // A helper function that calls the real rocBLAS function together with error 71 // handling. 72 // 73 // rocblas_func: rocBLAS function pointer. 74 // rocblas_name: rocBLAS function name. 75 // stream: Stream to enqueue the BLAS operation onto. 76 // pointer_mode_host: Indicate if the pointer to a scalar value is from host 77 // (true) or device (false). 78 // err_on_failure: Whether to print an error if the rocBLAS function 79 // fails. args: Arguments of rocBLAS function. 80 template <typename FuncT, typename... Args> 81 bool DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, 82 bool pointer_mode_host, bool err_on_failure, 83 Args... args); 84 85 // Convenience functions that call DoBlasInternalImpl with different values 86 // for err_on_failure. 87 template <typename FuncT, typename... Args> DoBlasInternal(FuncT rocblas_func,Stream * stream,bool pointer_mode_host,Args...args)88 bool DoBlasInternal(FuncT rocblas_func, Stream *stream, 89 bool pointer_mode_host, Args... args) { 90 return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, 91 /*err_on_failure=*/true, args...); 92 } 93 template <typename FuncT, typename... 
Args> DoBlasInternalFailureOK(FuncT rocblas_func,Stream * stream,bool pointer_mode_host,Args...args)94 bool DoBlasInternalFailureOK(FuncT rocblas_func, Stream *stream, 95 bool pointer_mode_host, Args... args) { 96 return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, 97 /*err_on_failure=*/false, args...); 98 } 99 100 // A helper function to implement DoBlasGemmBatched interfaces for generic 101 // types. 102 template <typename T, typename FuncT> 103 port::Status DoBlasGemmBatchedInternal( 104 FuncT rocblas_func, Stream *stream, blas::Transpose transa, 105 blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha, 106 const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda, 107 const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta, 108 const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc, 109 int batch_count, ScratchAllocator *scratch_allocator); 110 111 // Helper function for implementing DoBlasGemmWithAlgorithm. 112 // 113 // We take alpha and beta by const reference because T might be Eigen::half, 114 // and we want to avoid pulling in a dependency on Eigen. When we pass the 115 // references to rocBLAS, we essentially reinterpret_cast to __half, which is 116 // safe because Eigen::half inherits from __half. 117 template <typename InT, typename OutT, typename CompT> 118 bool DoBlasGemmWithAlgorithmImpl( 119 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 120 uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a, 121 int lda, const DeviceMemory<InT> &b, int ldb, const CompT &beta, 122 DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type, 123 blas::AlgorithmType algorithm, 124 blas::ProfileResult *output_profile_result); 125 126 // Helper function for implementing DoBlasGemmWithProfiling. 
127 template <typename T, typename ParamType> 128 bool DoBlasGemmWithProfilingImpl( 129 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, 130 uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a, 131 int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta, 132 DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result); 133 134 // Helper function for implementing DoBlasGemvWithProfiling. 135 template <typename T> 136 bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans, 137 uint64 m, uint64 n, const T &alpha, 138 const DeviceMemory<T> &a, int lda, 139 const DeviceMemory<T> &x, int incx, 140 const T &beta, DeviceMemory<T> *y, int incy, 141 blas::ProfileResult *output_profile_result); 142 143 // mutex that guards the rocBLAS handle for this device. 144 mutex mu_; 145 146 // GpuExecutor which instantiated this ROCMBlas. 147 // Immutable post-initialization. 148 GpuExecutor *parent_; 149 150 // rocBLAS library handle on the device. 151 rocblas_handle blas_ GUARDED_BY(mu_); 152 153 SE_DISALLOW_COPY_AND_ASSIGN(ROCMBlas); 154 }; 155 156 } // namespace gpu 157 } // namespace stream_executor 158 159 #endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_ 160