1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "rocm/include/rocblas.h"
17 
18 #include "tensorflow/stream_executor/rocm/rocm_blas.h"
19 
20 #define EIGEN_USE_GPU
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 
23 #include <assert.h>
24 #include <complex>
25 
26 #include "absl/strings/str_cat.h"
27 #include "tensorflow/stream_executor/device_memory.h"
28 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
29 #include "tensorflow/stream_executor/gpu/gpu_executor.h"
30 #include "tensorflow/stream_executor/gpu/gpu_helpers.h"
31 #include "tensorflow/stream_executor/gpu/gpu_stream.h"
32 #include "tensorflow/stream_executor/gpu/gpu_timer.h"
33 #include "tensorflow/stream_executor/lib/env.h"
34 #include "tensorflow/stream_executor/lib/initialize.h"
35 #include "tensorflow/stream_executor/lib/status.h"
36 #include "tensorflow/stream_executor/lib/status_macros.h"
37 #include "tensorflow/stream_executor/lib/stringprintf.h"
38 #include "tensorflow/stream_executor/platform/dso_loader.h"
39 #include "tensorflow/stream_executor/platform/logging.h"
40 #include "tensorflow/stream_executor/platform/port.h"
41 #include "tensorflow/stream_executor/plugin_registry.h"
42 #include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
43 #include "tensorflow/stream_executor/scratch_allocator.h"
44 #include "tensorflow/stream_executor/stream_executor.h"
45 
46 namespace stream_executor {
47 namespace gpu {
48 
49 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocBlasPlugin);
50 
51 namespace wrap {
52 
// STREAM_EXECUTOR_ROCBLAS_WRAP(name) defines a shim struct together with an
// object called `name` in this namespace. Invoking that object as
// name(parent, args...) constructs a ScopedActivateExecutorContext for
// `parent` and then forwards `args...` to the rocBLAS function `name`. The
// shim's static member kName holds the stringified function name.
#ifdef PLATFORM_GOOGLE
// PLATFORM_GOOGLE variant: the shim calls ::name directly.
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name)                       \
  struct WrapperShim__##__name {                                   \
    static const char *kName;                                      \
    template <typename... Args>                                    \
    rocblas_status operator()(GpuExecutor *parent, Args... args) { \
      gpu::ScopedActivateExecutorContext sac{parent};              \
      return ::__name(args...);                                    \
    }                                                              \
  } __name;                                                        \
  const char *WrapperShim__##__name::kName = #__name;

// V2 spelling; expands to exactly the same shim as
// STREAM_EXECUTOR_ROCBLAS_WRAP.
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
  STREAM_EXECUTOR_ROCBLAS_WRAP(__name)

#else

// Dynamic-loading variant: GetDsoHandle() takes the rocBLAS library handle
// from CachedDsoLoader (ValueOrDie on the result), LoadOrDie() looks kName up
// in that library and CHECK-fails when the lookup status is not ok, and
// DynLoad() keeps the resolved pointer in a function-local static so the
// lookup happens only when that static is initialized.
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name)                              \
  struct DynLoadShim__##__name {                                          \
    static const char *kName;                                             \
    using FuncPtrT = std::add_pointer<decltype(::__name)>::type;          \
    static void *GetDsoHandle() {                                         \
      auto s = internal::CachedDsoLoader::GetRocblasDsoHandle();          \
      return s.ValueOrDie();                                              \
    }                                                                     \
    static FuncPtrT LoadOrDie() {                                         \
      void *f;                                                            \
      auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
                                                          kName, &f);     \
      CHECK(s.ok()) << "could not find " << kName                         \
                    << " in rocblas DSO; dlerror: " << s.error_message(); \
      return reinterpret_cast<FuncPtrT>(f);                               \
    }                                                                     \
    static FuncPtrT DynLoad() {                                           \
      static FuncPtrT f = LoadOrDie();                                    \
      return f;                                                           \
    }                                                                     \
    template <typename... Args>                                           \
    rocblas_status operator()(GpuExecutor *parent, Args... args) {        \
      gpu::ScopedActivateExecutorContext sac{parent};                     \
      return DynLoad()(args...);                                          \
    }                                                                     \
  } __name;                                                               \
  const char *DynLoadShim__##__name::kName = #__name;

// V2 spelling; expands to exactly the same shim as
// STREAM_EXECUTOR_ROCBLAS_WRAP.
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
  STREAM_EXECUTOR_ROCBLAS_WRAP(__name)

#endif
102 
// X-macro over rocBLAS routine names: ROCBLAS_BLAS_ROUTINE_EACH(m) expands to
// m(routine) for every routine listed outside a block comment. Names that sit
// inside /* ... */ are disabled and contribute nothing to the expansion.
#define ROCBLAS_BLAS_ROUTINE_EACH(__macro)                                     \
  __macro(rocblas_snrm2) __macro(rocblas_dnrm2) /*  __macro(rocblas_scnrm2)    \
                                                  __macro(rocblas_dznrm2) */   \
      __macro(rocblas_sdot)                                                    \
          __macro(rocblas_ddot) /*  __macro(rocblas_cdotu)                     \
                                  __macro(rocblas_cdotc)                       \
                                  __macro(rocblas_zdotu)                       \
                                  __macro(rocblas_zdotc)                    */ \
      __macro(rocblas_sscal)                                                   \
          __macro(rocblas_dscal) /*  __macro(rocblas_cscal)                    \
                                   __macro(rocblas_csscal)                     \
                                   __macro(rocblas_zscal)                      \
                                   __macro(rocblas_zdscal) */                  \
      __macro(rocblas_saxpy)                                                   \
          __macro(rocblas_daxpy) /*  __macro(rocblas_caxpy)                    \
                                   __macro(rocblas_zaxpy) */                   \
      __macro(rocblas_scopy)                                                   \
          __macro(rocblas_dcopy) /*  __macro(rocblas_ccopy)                    \
                                   __macro(rocblas_zcopy) */                   \
      __macro(rocblas_sswap)                                                   \
          __macro(rocblas_dswap) /*  __macro(rocblas_cswap)                    \
                                   __macro(rocblas_zswap) */                   \
      __macro(rocblas_isamax)                                                  \
          __macro(rocblas_idamax) /*  __macro(rocblas_icamax)                  \
                                    __macro(rocblas_izamax) */                 \
      __macro(rocblas_isamin)                                                  \
          __macro(rocblas_idamin) /*  __macro(rocblas_icamin)                  \
                                    __macro(rocblas_izamin) */                 \
      __macro(rocblas_sasum)                                                   \
          __macro(rocblas_dasum) /*  __macro(rocblas_scasum)                   \
                                   __macro(rocblas_dzasum)                     \
                                   __macro(rocblas_srot)                       \
                                   __macro(rocblas_drot)                       \
                                   __macro(rocblas_crot)                       \
                                   __macro(rocblas_csrot)                      \
                                   __macro(rocblas_zrot)                       \
                                   __macro(rocblas_zdrot)                      \
                                   __macro(rocblas_srotg)                      \
                                   __macro(rocblas_drotg)                      \
                                   __macro(rocblas_Crotg)                      \
                                   __macro(rocblas_crotg)                      \
                                   __macro(rocblas_zrotm)                      \
                                   __macro(rocblas_drotm)                      \
                                   __macro(rocblas_srotmg)                     \
                                   __macro(rocblas_drotmg) */                  \
      __macro(rocblas_sgemv)                                                   \
          __macro(rocblas_dgemv) /*  __macro(rocblas_cgemv)                    \
                                   __macro(rocblas_zgemv)                      \
                                   __macro(rocblas_sgbmv)                      \
                                   __macro(rocblas_dgbmv)                      \
                                   __macro(rocblas_cgbmv)                      \
                                   __macro(rocblas_zgbmv)                      \
                                   __macro(rocblas_strmv)                      \
                                   __macro(rocblas_dtrmv)                      \
                                   __macro(rocblas_ctrmv)                      \
                                   __macro(rocblas_ztrmv)                      \
                                   __macro(rocblas_stbmv)                      \
                                   __macro(rocblas_dtbmv)                      \
                                   __macro(rocblas_ctbmv)                      \
                                   __macro(rocblas_ztbmv)                      \
                                   __macro(rocblas_stpmv)                      \
                                   __macro(rocblas_dtpmv)                      \
                                   __macro(rocblas_ctpmv)                      \
                                   __macro(rocblas_ztpmv)                      \
                                   __macro(rocblas_strsv)                      \
                                   __macro(rocblas_dtrsv)                      \
                                   __macro(rocblas_ctrsv)                      \
                                   __macro(rocblas_ztrsv)                      \
                                   __macro(rocblas_stpsv)                      \
                                   __macro(rocblas_dtpsv)                      \
                                   __macro(rocblas_ctpsv)                      \
                                   __macro(rocblas_ztpsv)                      \
                                   __macro(rocblas_stbsv)                      \
                                   __macro(rocblas_dtbsv)                      \
                                   __macro(rocblas_ctbsv)                      \
                                   __macro(rocblas_ztbsv)                      \
                                   __macro(rocblas_ssymv)                      \
                                   __macro(rocblas_dsymv)                      \
                                   __macro(rocblas_csymv)                      \
                                   __macro(rocblas_zsymv)                      \
                                   __macro(rocblas_chemv)                      \
                                   __macro(rocblas_zhemv)                      \
                                   __macro(rocblas_ssbmv)                      \
                                   __macro(rocblas_dsbmv)                      \
                                   __macro(rocblas_chbmv)                      \
                                   __macro(rocblas_zhbmv)                      \
                                   __macro(rocblas_sspmv)                      \
                                   __macro(rocblas_dspmv)                      \
                                   __macro(rocblas_chpmv)                      \
                                   __macro(rocblas_zhpmv) */                   \
      __macro(rocblas_sger)                                                    \
          __macro(rocblas_dger) /*  __macro(rocblas_cgeru)                     \
                                  __macro(rocblas_cgerc)                       \
                                  __macro(rocblas_zgeru)                       \
                                  __macro(rocblas_zgerc)                    */ \
      __macro(rocblas_ssyr)                                                    \
          __macro(rocblas_dsyr) /*  __macro(rocblas_csyr)                      \
                                  __macro(rocblas_zsyr)                        \
                                  __macro(rocblas_cher)                        \
                                  __macro(rocblas_zher)                        \
                                  __macro(rocblas_sspr)                        \
                                  __macro(rocblas_dspr)                        \
                                  __macro(rocblas_chpr)                        \
                                  __macro(rocblas_zhpr)                        \
                                  __macro(rocblas_ssyr2)                       \
                                  __macro(rocblas_dsyr2)                       \
                                  __macro(rocblas_csyr2)                       \
                                  __macro(rocblas_zsyr2)                       \
                                  __macro(rocblas_cher2)                       \
                                  __macro(rocblas_zher2)                       \
                                  __macro(rocblas_sspr2)                       \
                                  __macro(rocblas_dspr2)                       \
                                  __macro(rocblas_chpr2)                       \
                                  __macro(rocblas_zhpr2)                    */ \
      __macro(rocblas_sgemm) __macro(rocblas_dgemm)                            \
          __macro(rocblas_hgemm) /*  __macro(rocblas_cgemm)                    \
                                   __macro(rocblas_zgemm)                      \
                                   __macro(rocblas_ssyrk)                      \
                                   __macro(rocblas_dsyrk)                      \
                                   __macro(rocblas_csyrk)                      \
                                   __macro(rocblas_zsyrk)                      \
                                   __macro(rocblas_cherk)                      \
                                   __macro(rocblas_zherk)                      \
                                   __macro(rocblas_ssyr2k)                     \
                                   __macro(rocblas_dsyr2k)                     \
                                   __macro(rocblas_csyr2k)                     \
                                   __macro(rocblas_zsyr2k)                     \
                                   __macro(rocblas_cher2k)                     \
                                   __macro(rocblas_zher2k)                     \
                                   __macro(rocblas_ssyrkx)                     \
                                   __macro(rocblas_dsyrkx)                     \
                                   __macro(rocblas_csyrkx)                     \
                                   __macro(rocblas_zsyrkx)                     \
                                   __macro(rocblas_cherkx)                     \
                                   __macro(rocblas_zherkx)                     \
                                   __macro(rocblas_ssymm)                      \
                                   __macro(rocblas_dsymm)                      \
                                   __macro(rocblas_csymm)                      \
                                   __macro(rocblas_zsymm)                      \
                                   __macro(rocblas_chemm)                      \
                                   __macro(rocblas_zhemm) */                   \
      __macro(rocblas_strsm)                                                   \
          __macro(rocblas_dtrsm) /*  __macro(rocblas_ctrsm)                    \
                                   __macro(rocblas_ztrsm)                      \
                                   __macro(rocblas_strmm)                      \
                                   __macro(rocblas_dtrmm)                      \
                                   __macro(rocblas_ctrmm)                      \
                                   __macro(rocblas_ztrmm) */                   \
      __macro(rocblas_sgeam)                                                   \
          __macro(rocblas_dgeam) /*  __macro(rocblas_cgeam)                    \
                                   __macro(rocblas_zgeam)                      \
                                   __macro(rocblas_sdgmm)                      \
                                   __macro(rocblas_ddgmm)                      \
                                   __macro(rocblas_cdgmm)                      \
                                   __macro(rocblas_zdgmm) */
258 
// Shim objects defined in namespace wrap: handle creation/destruction, stream
// binding, the strided-batched GEMM variants, and then one shim per routine
// enabled in ROCBLAS_BLAS_ROUTINE_EACH. The commented-out lines define
// nothing, so wrap:: has no shim for those entry points.
STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_create_handle)
STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_destroy_handle)
STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_set_stream)
// STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_set_pointer_mode)
// STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_get_pointer_mode)
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_hgemm_strided_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_strided_batched)
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_strided_batched)
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_batched)
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_batched)
ROCBLAS_BLAS_ROUTINE_EACH(STREAM_EXECUTOR_ROCBLAS_V2_WRAP)
272 
273 }  // namespace wrap
274 
ToString(rocblas_status status)275 static string ToString(rocblas_status status) {
276   switch (status) {
277     case rocblas_status_success:
278       return "rocblas_status_success";
279     case rocblas_status_invalid_handle:
280       return "rocblas_status_invalid_handle";
281     case rocblas_status_not_implemented:
282       return "rocblas_status_not_implemented";
283     case rocblas_status_invalid_pointer:
284       return "rocblas_status_invalid_pointer";
285     case rocblas_status_invalid_size:
286       return "rocblas_status_invalid_size";
287     case rocblas_status_memory_error:
288       return "rocblas_status_memory_error";
289     case rocblas_status_internal_error:
290       return "rocblas_status_internal_error";
291     default:
292       return absl::StrCat("<invalid rocBLAS status: ", status, ">");
293   }
294 }
295 
Init()296 bool ROCMBlas::Init() {
297   rocblas_status ret = wrap::rocblas_create_handle(parent_, &blas_);
298   if (ret != rocblas_status_success) {
299     LOG(ERROR) << "failed to create rocBLAS handle: " << ToString(ret);
300     return false;
301   }
302 
303   return true;
304 }
305 
ROCMBlas(gpu::GpuExecutor * parent)306 ROCMBlas::ROCMBlas(gpu::GpuExecutor *parent)
307     : parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {}
308 
~ROCMBlas()309 ROCMBlas::~ROCMBlas() {
310   if (blas_ != nullptr) {
311     wrap::rocblas_destroy_handle(parent_, blas_);
312   }
313 }
314 
SetStream(Stream * stream)315 bool ROCMBlas::SetStream(Stream *stream) {
316   CHECK(stream != nullptr);
317   CHECK(AsGpuStreamValue(stream) != nullptr);
318   CHECK(blas_ != nullptr);
319   rocblas_status ret =
320       wrap::rocblas_set_stream(parent_, blas_, AsGpuStreamValue(stream));
321   if (ret != rocblas_status_success) {
322     LOG(ERROR) << "failed to set stream for rocBLAS calls: " << ToString(ret);
323     return false;
324   }
325 
326   return true;
327 }
328 
329 namespace {
330 
331 // Helper functions transforming blas arguments into rocBLAS arguments.
332 
ROCMBlasTranspose(blas::Transpose trans)333 rocblas_operation ROCMBlasTranspose(blas::Transpose trans) {
334   switch (trans) {
335     case blas::Transpose::kNoTranspose:
336       return rocblas_operation_none;
337     case blas::Transpose::kTranspose:
338       return rocblas_operation_transpose;
339     case blas::Transpose::kConjugateTranspose:
340       return rocblas_operation_conjugate_transpose;
341     default:
342       LOG(FATAL) << "Invalid value of blas::Transpose.";
343   }
344 }
345 
ROCMBlasUpperLower(blas::UpperLower uplo)346 rocblas_fill ROCMBlasUpperLower(blas::UpperLower uplo) {
347   switch (uplo) {
348     case blas::UpperLower::kUpper:
349       return rocblas_fill_upper;
350     case blas::UpperLower::kLower:
351       return rocblas_fill_lower;
352     default:
353       LOG(FATAL) << "Invalid value of blas::UpperLower.";
354   }
355 }
356 
ROCMBlasDiagonal(blas::Diagonal diag)357 rocblas_diagonal ROCMBlasDiagonal(blas::Diagonal diag) {
358   switch (diag) {
359     case blas::Diagonal::kUnit:
360       return rocblas_diagonal_unit;
361     case blas::Diagonal::kNonUnit:
362       return rocblas_diagonal_non_unit;
363     default:
364       LOG(FATAL) << "Invalid value of blas::Diagonal.";
365   }
366 }
367 
ROCMBlasSide(blas::Side side)368 rocblas_side ROCMBlasSide(blas::Side side) {
369   switch (side) {
370     case blas::Side::kLeft:
371       return rocblas_side_left;
372     case blas::Side::kRight:
373       return rocblas_side_right;
374     default:
375       LOG(FATAL) << "Invalid value of blas::Side.";
376   }
377 }
378 
379 }  // namespace
380 
381 template <typename FuncT, typename... Args>
DoBlasInternalImpl(FuncT rocblas_func,Stream * stream,bool pointer_mode_host,bool err_on_failure,Args...args)382 bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
383                                   bool pointer_mode_host, bool err_on_failure,
384                                   Args... args) {
385   mutex_lock lock{mu_};
386 
387   CHECK(blas_ != nullptr);
388   if (!SetStream(stream)) {
389     return false;
390   }
391 
392   rocblas_status ret = rocblas_func(parent_, blas_, args...);
393   if (err_on_failure && ret != rocblas_status_success) {
394     LOG(ERROR) << "failed to run ROCBLAS routine " << rocblas_func.kName << ": "
395                << ToString(ret);
396   }
397   return ret == rocblas_status_success;
398 }
399 
DoBlasAsum(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)400 bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
401                           const DeviceMemory<float> &x, int incx,
402                           DeviceMemory<float> *result) {
403   return DoBlasInternal(wrap::rocblas_sasum, stream,
404                         false /* = pointer_mode_host */, elem_count,
405                         GpuMemory(x), incx, GpuMemoryMutable(result));
406 }
407 
DoBlasAsum(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)408 bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
409                           const DeviceMemory<double> &x, int incx,
410                           DeviceMemory<double> *result) {
411   return DoBlasInternal(wrap::rocblas_dasum, stream,
412                         false /* = pointer_mode_host */, elem_count,
413                         GpuMemory(x), incx, GpuMemoryMutable(result));
414 }
415 
DoBlasAsum(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)416 bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
417                           const DeviceMemory<std::complex<float>> &x, int incx,
418                           DeviceMemory<float> *result) {
419   LOG(ERROR) << "rocBLAS does not currently support the ASUM operation "
420              << "for the \"complex<float>\" dataype";
421   return false;
422 }
423 
DoBlasAsum(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)424 bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
425                           const DeviceMemory<std::complex<double>> &x, int incx,
426                           DeviceMemory<double> *result) {
427   LOG(ERROR) << "rocBLAS does not currently support the ASUM operation "
428              << "for the \"complex<double>\" dataype";
429   return false;
430 }
431 
DoBlasAxpy(Stream * stream,uint64 elem_count,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)432 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
433                           const DeviceMemory<float> &x, int incx,
434                           DeviceMemory<float> *y, int incy) {
435   return DoBlasInternal(wrap::rocblas_saxpy, stream,
436                         true /* = pointer_mode_host */, elem_count, &alpha,
437                         GpuMemory(x), incx, GpuMemoryMutable(y), incy);
438 }
439 
DoBlasAxpy(Stream * stream,uint64 elem_count,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)440 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
441                           const DeviceMemory<double> &x, int incx,
442                           DeviceMemory<double> *y, int incy) {
443   return DoBlasInternal(wrap::rocblas_daxpy, stream,
444                         true /* = pointer_mode_host */, elem_count, &alpha,
445                         GpuMemory(x), incx, GpuMemoryMutable(y), incy);
446 }
447 
DoBlasAxpy(Stream * stream,uint64 elem_count,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)448 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
449                           std::complex<float> alpha,
450                           const DeviceMemory<std::complex<float>> &x, int incx,
451                           DeviceMemory<std::complex<float>> *y, int incy) {
452   LOG(ERROR) << "rocBLAS does not currently support the AXPY operation "
453              << "for the \"complex<float>\" dataype";
454   return false;
455 }
456 
DoBlasAxpy(Stream * stream,uint64 elem_count,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)457 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
458                           std::complex<double> alpha,
459                           const DeviceMemory<std::complex<double>> &x, int incx,
460                           DeviceMemory<std::complex<double>> *y, int incy) {
461   LOG(ERROR) << "rocBLAS does not currently support the AXPY operation "
462              << "for the \"complex<double>\" dataype";
463   return false;
464 }
465 
DoBlasCopy(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)466 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
467                           const DeviceMemory<float> &x, int incx,
468                           DeviceMemory<float> *y, int incy) {
469   return DoBlasInternal(wrap::rocblas_scopy, stream,
470                         true /* = pointer_mode_host */, elem_count,
471                         GpuMemory(x), incx, GpuMemoryMutable(y), incy);
472 }
473 
DoBlasCopy(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)474 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
475                           const DeviceMemory<double> &x, int incx,
476                           DeviceMemory<double> *y, int incy) {
477   return DoBlasInternal(wrap::rocblas_dcopy, stream,
478                         true /* = pointer_mode_host */, elem_count,
479                         GpuMemory(x), incx, GpuMemoryMutable(y), incy);
480 }
481 
DoBlasCopy(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)482 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
483                           const DeviceMemory<std::complex<float>> &x, int incx,
484                           DeviceMemory<std::complex<float>> *y, int incy) {
485   LOG(ERROR) << "rocBLAS does not currently support the COPY operation "
486              << "for the \"complex<float>\" dataype";
487   return false;
488 }
489 
DoBlasCopy(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)490 bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
491                           const DeviceMemory<std::complex<double>> &x, int incx,
492                           DeviceMemory<std::complex<double>> *y, int incy) {
493   LOG(ERROR) << "rocBLAS does not currently support the COPY operation "
494              << "for the \"complex<double>\" dataype";
495   return false;
496 }
497 
DoBlasDot(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * result)498 bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count,
499                          const DeviceMemory<float> &x, int incx,
500                          const DeviceMemory<float> &y, int incy,
501                          DeviceMemory<float> *result) {
502   return DoBlasInternal(
503       wrap::rocblas_sdot, stream, false /* = pointer_mode_host */, elem_count,
504       GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result));
505 }
506 
DoBlasDot(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * result)507 bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count,
508                          const DeviceMemory<double> &x, int incx,
509                          const DeviceMemory<double> &y, int incy,
510                          DeviceMemory<double> *result) {
511   return DoBlasInternal(
512       wrap::rocblas_ddot, stream, false /* = pointer_mode_host */, elem_count,
513       GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result));
514 }
515 
DoBlasDotc(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)516 bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count,
517                           const DeviceMemory<std::complex<float>> &x, int incx,
518                           const DeviceMemory<std::complex<float>> &y, int incy,
519                           DeviceMemory<std::complex<float>> *result) {
520   LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
521              << "for the \"complex<float>\" dataype";
522   return false;
523 }
524 
DoBlasDotc(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)525 bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count,
526                           const DeviceMemory<std::complex<double>> &x, int incx,
527                           const DeviceMemory<std::complex<double>> &y, int incy,
528                           DeviceMemory<std::complex<double>> *result) {
529   LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
530              << "for the \"complex<double>\" dataype";
531   return false;
532 }
533 
DoBlasDotu(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)534 bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count,
535                           const DeviceMemory<std::complex<float>> &x, int incx,
536                           const DeviceMemory<std::complex<float>> &y, int incy,
537                           DeviceMemory<std::complex<float>> *result) {
538   LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
539              << "for the \"complex<float>\" dataype";
540   return false;
541 }
542 
DoBlasDotu(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)543 bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count,
544                           const DeviceMemory<std::complex<double>> &x, int incx,
545                           const DeviceMemory<std::complex<double>> &y, int incy,
546                           DeviceMemory<std::complex<double>> *result) {
547   LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
548              << "for the \"complex<double>\" dataype";
549   return false;
550 }
551 
DoBlasNrm2(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)552 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
553                           const DeviceMemory<float> &x, int incx,
554                           DeviceMemory<float> *result) {
555   return DoBlasInternal(wrap::rocblas_snrm2, stream,
556                         false /* = pointer_mode_host */, elem_count,
557                         GpuMemory(x), incx, GpuMemoryMutable(result));
558 }
559 
DoBlasNrm2(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)560 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
561                           const DeviceMemory<double> &x, int incx,
562                           DeviceMemory<double> *result) {
563   return DoBlasInternal(wrap::rocblas_dnrm2, stream,
564                         false /* = pointer_mode_host */, elem_count,
565                         GpuMemory(x), incx, GpuMemoryMutable(result));
566 }
567 
DoBlasNrm2(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)568 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
569                           const DeviceMemory<std::complex<float>> &x, int incx,
570                           DeviceMemory<float> *result) {
571   LOG(ERROR) << "rocBLAS does not currently support the NRM2 operation "
572              << "for the \"complex<float>\" dataype";
573   return false;
574 }
575 
DoBlasNrm2(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)576 bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
577                           const DeviceMemory<std::complex<double>> &x, int incx,
578                           DeviceMemory<double> *result) {
579   LOG(ERROR) << "rocBLAS does not currently support the NRM2 operation "
580              << "for the \"complex<double>\" dataype";
581   return false;
582 }
583 
DoBlasRot(Stream * stream,uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,float c,float s)584 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
585                          DeviceMemory<float> *x, int incx,
586                          DeviceMemory<float> *y, int incy, float c, float s) {
587   LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
588              << "for the \"float\" dataype";
589   return false;
590 }
591 
DoBlasRot(Stream * stream,uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,double c,double s)592 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
593                          DeviceMemory<double> *x, int incx,
594                          DeviceMemory<double> *y, int incy, double c,
595                          double s) {
596   LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
597              << "for the \"double\" dataype";
598   return false;
599 }
600 
DoBlasRot(Stream * stream,uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy,float c,float s)601 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
602                          DeviceMemory<std::complex<float>> *x, int incx,
603                          DeviceMemory<std::complex<float>> *y, int incy,
604                          float c, float s) {
605   LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
606              << "for the \"complex<float>\" dataype";
607   return false;
608 }
609 
DoBlasRot(Stream * stream,uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy,double c,double s)610 bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
611                          DeviceMemory<std::complex<double>> *x, int incx,
612                          DeviceMemory<std::complex<double>> *y, int incy,
613                          double c, double s) {
614   LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
615              << "for the \"complex<double>\" dataype";
616   return false;
617 }
618 
DoBlasRotg(Stream * stream,DeviceMemory<float> * a,DeviceMemory<float> * b,DeviceMemory<float> * c,DeviceMemory<float> * s)619 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
620                           DeviceMemory<float> *b, DeviceMemory<float> *c,
621                           DeviceMemory<float> *s) {
622   LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
623              << "for the \"float\" dataype";
624   return false;
625 }
626 
DoBlasRotg(Stream * stream,DeviceMemory<double> * a,DeviceMemory<double> * b,DeviceMemory<double> * c,DeviceMemory<double> * s)627 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
628                           DeviceMemory<double> *b, DeviceMemory<double> *c,
629                           DeviceMemory<double> *s) {
630   LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
631              << "for the \"double\" dataype";
632   return false;
633 }
634 
DoBlasRotg(Stream * stream,DeviceMemory<std::complex<float>> * a,DeviceMemory<std::complex<float>> * b,DeviceMemory<float> * c,DeviceMemory<std::complex<float>> * s)635 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
636                           DeviceMemory<std::complex<float>> *b,
637                           DeviceMemory<float> *c,
638                           DeviceMemory<std::complex<float>> *s) {
639   LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
640              << "for the \"complex<float>\" dataype";
641   return false;
642 }
643 
DoBlasRotg(Stream * stream,DeviceMemory<std::complex<double>> * a,DeviceMemory<std::complex<double>> * b,DeviceMemory<double> * c,DeviceMemory<std::complex<double>> * s)644 bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
645                           DeviceMemory<std::complex<double>> *b,
646                           DeviceMemory<double> *c,
647                           DeviceMemory<std::complex<double>> *s) {
648   LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
649              << "for the \"complex<double>\" dataype";
650   return false;
651 }
652 
DoBlasRotm(Stream * stream,uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,const DeviceMemory<float> & param)653 bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count,
654                           DeviceMemory<float> *x, int incx,
655                           DeviceMemory<float> *y, int incy,
656                           const DeviceMemory<float> &param) {
657   LOG(ERROR) << "rocBLAS does not currently support the ROTM operation "
658              << "for the \"float\" dataype";
659   return false;
660 }
661 
DoBlasRotm(Stream * stream,uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,const DeviceMemory<double> & param)662 bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count,
663                           DeviceMemory<double> *x, int incx,
664                           DeviceMemory<double> *y, int incy,
665                           const DeviceMemory<double> &param) {
666   LOG(ERROR) << "rocBLAS does not currently support the ROTM operation "
667              << "for the \"double\" dataype";
668   return false;
669 }
670 
DoBlasRotmg(Stream * stream,DeviceMemory<float> * d1,DeviceMemory<float> * d2,DeviceMemory<float> * x1,const DeviceMemory<float> & y1,DeviceMemory<float> * param)671 bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
672                            DeviceMemory<float> *d2, DeviceMemory<float> *x1,
673                            const DeviceMemory<float> &y1,
674                            DeviceMemory<float> *param) {
675   LOG(ERROR) << "rocBLAS does not currently support the ROTMG operation "
676              << "for the \"float\" dataype";
677   return false;
678 }
679 
DoBlasRotmg(Stream * stream,DeviceMemory<double> * d1,DeviceMemory<double> * d2,DeviceMemory<double> * x1,const DeviceMemory<double> & y1,DeviceMemory<double> * param)680 bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
681                            DeviceMemory<double> *d2, DeviceMemory<double> *x1,
682                            const DeviceMemory<double> &y1,
683                            DeviceMemory<double> *param) {
684   LOG(ERROR) << "rocBLAS does not currently support the ROTMG operation "
685              << "for the \"double\" dataype";
686   return false;
687 }
688 
DoBlasScal(Stream * stream,uint64 elem_count,float alpha,DeviceMemory<float> * x,int incx)689 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
690                           DeviceMemory<float> *x, int incx) {
691   return DoBlasInternal(wrap::rocblas_sscal, stream,
692                         true /* = pointer_mode_host */, elem_count, &alpha,
693                         GpuMemoryMutable(x), incx);
694 }
695 
DoBlasScal(Stream * stream,uint64 elem_count,double alpha,DeviceMemory<double> * x,int incx)696 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
697                           DeviceMemory<double> *x, int incx) {
698   return DoBlasInternal(wrap::rocblas_dscal, stream,
699                         true /* = pointer_mode_host */, elem_count, &alpha,
700                         GpuMemoryMutable(x), incx);
701 }
702 
DoBlasScal(Stream * stream,uint64 elem_count,float alpha,DeviceMemory<std::complex<float>> * x,int incx)703 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
704                           DeviceMemory<std::complex<float>> *x, int incx) {
705   LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
706              << "for the \"complex<float>\" dataype";
707   return false;
708 }
709 
DoBlasScal(Stream * stream,uint64 elem_count,double alpha,DeviceMemory<std::complex<double>> * x,int incx)710 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
711                           DeviceMemory<std::complex<double>> *x, int incx) {
712   LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
713              << "for the \"complex<double>\" dataype";
714   return false;
715 }
716 
DoBlasScal(Stream * stream,uint64 elem_count,std::complex<float> alpha,DeviceMemory<std::complex<float>> * x,int incx)717 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count,
718                           std::complex<float> alpha,
719                           DeviceMemory<std::complex<float>> *x, int incx) {
720   LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
721              << "for the \"complex<float>\" dataype";
722   return false;
723 }
724 
DoBlasScal(Stream * stream,uint64 elem_count,std::complex<double> alpha,DeviceMemory<std::complex<double>> * x,int incx)725 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count,
726                           std::complex<double> alpha,
727                           DeviceMemory<std::complex<double>> *x, int incx) {
728   LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
729              << "for the \"complex<double>\" dataype";
730   return false;
731 }
732 
DoBlasSwap(Stream * stream,uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy)733 bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
734                           DeviceMemory<float> *x, int incx,
735                           DeviceMemory<float> *y, int incy) {
736   return DoBlasInternal(wrap::rocblas_sswap, stream,
737                         true /* = pointer_mode_host */, elem_count,
738                         GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy);
739 }
740 
DoBlasSwap(Stream * stream,uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy)741 bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
742                           DeviceMemory<double> *x, int incx,
743                           DeviceMemory<double> *y, int incy) {
744   return DoBlasInternal(wrap::rocblas_dswap, stream,
745                         true /* = pointer_mode_host */, elem_count,
746                         GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy);
747 }
748 
DoBlasSwap(Stream * stream,uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy)749 bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
750                           DeviceMemory<std::complex<float>> *x, int incx,
751                           DeviceMemory<std::complex<float>> *y, int incy) {
752   LOG(ERROR) << "rocBLAS does not currently support the SWAP operation "
753              << "for the \"complex<float>\" dataype";
754   return false;
755 }
756 
DoBlasSwap(Stream * stream,uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy)757 bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
758                           DeviceMemory<std::complex<double>> *x, int incx,
759                           DeviceMemory<std::complex<double>> *y, int incy) {
760   LOG(ERROR) << "rocBLAS does not currently support the SWAP operation "
761              << "for the \"complex<double>\" dataype";
762   return false;
763 }
764 
DoBlasIamax(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)765 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
766                            const DeviceMemory<float> &x, int incx,
767                            DeviceMemory<int> *result) {
768   return DoBlasInternal(wrap::rocblas_isamax, stream,
769                         false /* = pointer_mode_host */, elem_count,
770                         GpuMemory(x), incx, GpuMemoryMutable(result));
771 }
772 
DoBlasIamax(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)773 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
774                            const DeviceMemory<double> &x, int incx,
775                            DeviceMemory<int> *result) {
776   return DoBlasInternal(wrap::rocblas_idamax, stream,
777                         false /* = pointer_mode_host */, elem_count,
778                         GpuMemory(x), incx, GpuMemoryMutable(result));
779 }
780 
DoBlasIamax(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)781 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
782                            const DeviceMemory<std::complex<float>> &x, int incx,
783                            DeviceMemory<int> *result) {
784   LOG(ERROR) << "rocBLAS does not currently support the AMAX operation "
785              << "for the \"complex<float>\" dataype";
786   return false;
787 }
788 
DoBlasIamax(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)789 bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
790                            const DeviceMemory<std::complex<double>> &x,
791                            int incx, DeviceMemory<int> *result) {
792   LOG(ERROR) << "rocBLAS does not currently support the AMAX operation "
793              << "for the \"complex<double>\" dataype";
794   return false;
795 }
796 
DoBlasIamin(Stream * stream,uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)797 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
798                            const DeviceMemory<float> &x, int incx,
799                            DeviceMemory<int> *result) {
800   return DoBlasInternal(
801       wrap::rocblas_isamin, stream, false /* = pointer_mode_host */, elem_count,
802       GpuComplex(GpuMemory(x)), incx, GpuMemoryMutable(result));
803 }
804 
DoBlasIamin(Stream * stream,uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)805 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
806                            const DeviceMemory<double> &x, int incx,
807                            DeviceMemory<int> *result) {
808   return DoBlasInternal(
809       wrap::rocblas_idamin, stream, false /* = pointer_mode_host */, elem_count,
810       GpuComplex(GpuMemory(x)), incx, GpuMemoryMutable(result));
811 }
812 
DoBlasIamin(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)813 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
814                            const DeviceMemory<std::complex<float>> &x, int incx,
815                            DeviceMemory<int> *result) {
816   LOG(ERROR) << "rocBLAS does not currently support the AMIN operation "
817              << "for the \"complex<float>\" dataype";
818   return false;
819 }
820 
DoBlasIamin(Stream * stream,uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)821 bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
822                            const DeviceMemory<std::complex<double>> &x,
823                            int incx, DeviceMemory<int> *result) {
824   LOG(ERROR) << "rocBLAS does not currently support the AMIN operation "
825              << "for the \"complex<double>\" dataype";
826   return false;
827 }
828 
DoBlasGbmv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)829 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
830                           uint64 n, uint64 kl, uint64 ku, float alpha,
831                           const DeviceMemory<float> &a, int lda,
832                           const DeviceMemory<float> &x, int incx, float beta,
833                           DeviceMemory<float> *y, int incy) {
834   LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
835              << "for the \"float\" dataype";
836   return false;
837 }
838 
DoBlasGbmv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)839 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
840                           uint64 n, uint64 kl, uint64 ku, double alpha,
841                           const DeviceMemory<double> &a, int lda,
842                           const DeviceMemory<double> &x, int incx, double beta,
843                           DeviceMemory<double> *y, int incy) {
844   LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
845              << "for the \"double\" dataype";
846   return false;
847 }
848 
DoBlasGbmv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)849 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
850                           uint64 n, uint64 kl, uint64 ku,
851                           std::complex<float> alpha,
852                           const DeviceMemory<std::complex<float>> &a, int lda,
853                           const DeviceMemory<std::complex<float>> &x, int incx,
854                           std::complex<float> beta,
855                           DeviceMemory<std::complex<float>> *y, int incy) {
856   LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
857              << "for the \"complex<float>\" dataype";
858   return false;
859 }
860 
DoBlasGbmv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)861 bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
862                           uint64 n, uint64 kl, uint64 ku,
863                           std::complex<double> alpha,
864                           const DeviceMemory<std::complex<double>> &a, int lda,
865                           const DeviceMemory<std::complex<double>> &x, int incx,
866                           std::complex<double> beta,
867                           DeviceMemory<std::complex<double>> *y, int incy) {
868   LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
869              << "for the \"complex<double>\" dataype";
870   return false;
871 }
872 
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)873 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
874                           uint64 n, float alpha, const DeviceMemory<float> &a,
875                           int lda, const DeviceMemory<float> &x, int incx,
876                           float beta, DeviceMemory<float> *y, int incy) {
877   return DoBlasInternal(
878       wrap::rocblas_sgemv, stream, true /* = pointer_mode_host */,
879       ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
880       incx, &beta, GpuMemoryMutable(y), incy);
881 }
882 
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)883 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
884                           uint64 n, double alpha, const DeviceMemory<double> &a,
885                           int lda, const DeviceMemory<double> &x, int incx,
886                           double beta, DeviceMemory<double> *y, int incy) {
887   return DoBlasInternal(
888       wrap::rocblas_dgemv, stream, true /* = pointer_mode_host */,
889       ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
890       incx, &beta, GpuMemoryMutable(y), incy);
891 }
892 
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)893 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
894                           uint64 n, std::complex<float> alpha,
895                           const DeviceMemory<std::complex<float>> &a, int lda,
896                           const DeviceMemory<std::complex<float>> &x, int incx,
897                           std::complex<float> beta,
898                           DeviceMemory<std::complex<float>> *y, int incy) {
899   LOG(ERROR) << "rocBLAS does not currently support the GEMV operation "
900              << "for the \"complex<float>\" dataype";
901   return false;
902 }
903 
DoBlasGemv(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)904 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
905                           uint64 n, std::complex<double> alpha,
906                           const DeviceMemory<std::complex<double>> &a, int lda,
907                           const DeviceMemory<std::complex<double>> &x, int incx,
908                           std::complex<double> beta,
909                           DeviceMemory<std::complex<double>> *y, int incy) {
910   LOG(ERROR) << "rocBLAS does not currently support the GEMV operation "
911              << "for the \"complex<double>\" dataype";
912   return false;
913 }
914 
DoBlasGer(Stream * stream,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)915 bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
916                          const DeviceMemory<float> &x, int incx,
917                          const DeviceMemory<float> &y, int incy,
918                          DeviceMemory<float> *a, int lda) {
919   return DoBlasInternal(
920       wrap::rocblas_sger, stream, true /* = pointer_mode_host */, m, n, &alpha,
921       GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(a), lda);
922 }
923 
DoBlasGer(Stream * stream,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)924 bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
925                          const DeviceMemory<double> &x, int incx,
926                          const DeviceMemory<double> &y, int incy,
927                          DeviceMemory<double> *a, int lda) {
928   return DoBlasInternal(
929       wrap::rocblas_dger, stream, true /* = pointer_mode_host */, m, n, &alpha,
930       GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(a), lda);
931 }
932 
DoBlasGerc(Stream * stream,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)933 bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
934                           std::complex<float> alpha,
935                           const DeviceMemory<std::complex<float>> &x, int incx,
936                           const DeviceMemory<std::complex<float>> &y, int incy,
937                           DeviceMemory<std::complex<float>> *a, int lda) {
938   LOG(ERROR) << "rocBLAS does not currently support the GER operation "
939              << "for the \"complex<float>\" dataype";
940   return false;
941 }
942 
DoBlasGerc(Stream * stream,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)943 bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
944                           std::complex<double> alpha,
945                           const DeviceMemory<std::complex<double>> &x, int incx,
946                           const DeviceMemory<std::complex<double>> &y, int incy,
947                           DeviceMemory<std::complex<double>> *a, int lda) {
948   LOG(ERROR) << "rocBLAS does not currently support the GER operation "
949              << "for the \"complex<double>\" dataype";
950   return false;
951 }
952 
DoBlasGeru(Stream * stream,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)953 bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
954                           std::complex<float> alpha,
955                           const DeviceMemory<std::complex<float>> &x, int incx,
956                           const DeviceMemory<std::complex<float>> &y, int incy,
957                           DeviceMemory<std::complex<float>> *a, int lda) {
958   LOG(ERROR) << "rocBLAS does not currently support the GERU operation "
959              << "for the \"complex<float>\" dataype";
960   return false;
961 }
962 
DoBlasGeru(Stream * stream,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)963 bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
964                           std::complex<double> alpha,
965                           const DeviceMemory<std::complex<double>> &x, int incx,
966                           const DeviceMemory<std::complex<double>> &y, int incy,
967                           DeviceMemory<std::complex<double>> *a, int lda) {
968   LOG(ERROR) << "rocBLAS does not currently support the GERU operation "
969              << "for the \"complex<double>\" dataype";
970   return false;
971 }
972 
DoBlasHbmv(Stream * stream,blas::UpperLower uplo,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)973 bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
974                           uint64 k, std::complex<float> alpha,
975                           const DeviceMemory<std::complex<float>> &a, int lda,
976                           const DeviceMemory<std::complex<float>> &x, int incx,
977                           std::complex<float> beta,
978                           DeviceMemory<std::complex<float>> *y, int incy) {
979   LOG(ERROR) << "rocBLAS does not currently support the HBMV operation "
980              << "for the \"complex<float>\" dataype";
981   return false;
982 }
983 
DoBlasHbmv(Stream * stream,blas::UpperLower uplo,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)984 bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
985                           uint64 k, std::complex<double> alpha,
986                           const DeviceMemory<std::complex<double>> &a, int lda,
987                           const DeviceMemory<std::complex<double>> &x, int incx,
988                           std::complex<double> beta,
989                           DeviceMemory<std::complex<double>> *y, int incy) {
990   LOG(ERROR) << "rocBLAS does not currently support the HBMV operation "
991              << "for the \"complex<double>\" dataype";
992   return false;
993 }
994 
DoBlasHemv(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)995 bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
996                           std::complex<float> alpha,
997                           const DeviceMemory<std::complex<float>> &a, int lda,
998                           const DeviceMemory<std::complex<float>> &x, int incx,
999                           std::complex<float> beta,
1000                           DeviceMemory<std::complex<float>> *y, int incy) {
1001   LOG(ERROR) << "rocBLAS does not currently support the HEMV operation "
1002              << "for the \"complex<float>\" dataype";
1003   return false;
1004 }
1005 
DoBlasHemv(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)1006 bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
1007                           std::complex<double> alpha,
1008                           const DeviceMemory<std::complex<double>> &a, int lda,
1009                           const DeviceMemory<std::complex<double>> &x, int incx,
1010                           std::complex<double> beta,
1011                           DeviceMemory<std::complex<double>> *y, int incy) {
1012   LOG(ERROR) << "rocBLAS does not currently support the HEMV operation "
1013              << "for the \"complex<double>\" dataype";
1014   return false;
1015 }
1016 
DoBlasHer(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * a,int lda)1017 bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
1018                          float alpha,
1019                          const DeviceMemory<std::complex<float>> &x, int incx,
1020                          DeviceMemory<std::complex<float>> *a, int lda) {
1021   LOG(ERROR) << "rocBLAS does not currently support the HER operation "
1022              << "for the \"complex<float>\" dataype";
1023   return false;
1024 }
1025 
DoBlasHer(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * a,int lda)1026 bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
1027                          double alpha,
1028                          const DeviceMemory<std::complex<double>> &x, int incx,
1029                          DeviceMemory<std::complex<double>> *a, int lda) {
1030   LOG(ERROR) << "rocBLAS does not currently support the HER operation "
1031              << "for the \"complex<double>\" dataype";
1032   return false;
1033 }
1034 
DoBlasHer2(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)1035 bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
1036                           std::complex<float> alpha,
1037                           const DeviceMemory<std::complex<float>> &x, int incx,
1038                           const DeviceMemory<std::complex<float>> &y, int incy,
1039                           DeviceMemory<std::complex<float>> *a, int lda) {
1040   LOG(ERROR) << "rocBLAS does not currently support the HER2 operation "
1041              << "for the \"complex<float>\" dataype";
1042   return false;
1043 }
1044 
DoBlasHer2(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)1045 bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
1046                           std::complex<double> alpha,
1047                           const DeviceMemory<std::complex<double>> &x, int incx,
1048                           const DeviceMemory<std::complex<double>> &y, int incy,
1049                           DeviceMemory<std::complex<double>> *a, int lda) {
1050   LOG(ERROR) << "rocBLAS does not currently support the HER2 operation "
1051              << "for the \"complex<double>\" dataype";
1052   return false;
1053 }
1054 
DoBlasHpmv(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & ap,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)1055 bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1056                           std::complex<float> alpha,
1057                           const DeviceMemory<std::complex<float>> &ap,
1058                           const DeviceMemory<std::complex<float>> &x, int incx,
1059                           std::complex<float> beta,
1060                           DeviceMemory<std::complex<float>> *y, int incy) {
1061   LOG(ERROR) << "rocBLAS does not currently support the HPMV operation "
1062              << "for the \"complex<float>\" dataype";
1063   return false;
1064 }
1065 
DoBlasHpmv(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & ap,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)1066 bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1067                           std::complex<double> alpha,
1068                           const DeviceMemory<std::complex<double>> &ap,
1069                           const DeviceMemory<std::complex<double>> &x, int incx,
1070                           std::complex<double> beta,
1071                           DeviceMemory<std::complex<double>> *y, int incy) {
1072   LOG(ERROR) << "rocBLAS does not currently support the HPMV operation "
1073              << "for the \"complex<double>\" dataype";
1074   return false;
1075 }
1076 
DoBlasHpr(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * ap)1077 bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1078                          float alpha,
1079                          const DeviceMemory<std::complex<float>> &x, int incx,
1080                          DeviceMemory<std::complex<float>> *ap) {
1081   LOG(ERROR) << "rocBLAS does not currently support the HPR operation "
1082              << "for the \"complex<float>\" dataype";
1083   return false;
1084 }
1085 
DoBlasHpr(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * ap)1086 bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1087                          double alpha,
1088                          const DeviceMemory<std::complex<double>> &x, int incx,
1089                          DeviceMemory<std::complex<double>> *ap) {
1090   LOG(ERROR) << "rocBLAS does not currently support the HPR operation "
1091              << "for the \"complex<double>\" dataype";
1092   return false;
1093 }
1094 
DoBlasHpr2(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * ap)1095 bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1096                           std::complex<float> alpha,
1097                           const DeviceMemory<std::complex<float>> &x, int incx,
1098                           const DeviceMemory<std::complex<float>> &y, int incy,
1099                           DeviceMemory<std::complex<float>> *ap) {
1100   LOG(ERROR) << "rocBLAS does not currently support the HPR2 operation "
1101              << "for the \"complex<float>\" dataype";
1102   return false;
1103 }
1104 
DoBlasHpr2(Stream * stream,blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * ap)1105 bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1106                           std::complex<double> alpha,
1107                           const DeviceMemory<std::complex<double>> &x, int incx,
1108                           const DeviceMemory<std::complex<double>> &y, int incy,
1109                           DeviceMemory<std::complex<double>> *ap) {
1110   LOG(ERROR) << "rocBLAS does not currently support the HPR2 operation "
1111              << "for the \"complex<double>\" dataype";
1112   return false;
1113 }
1114 
DoBlasSbmv(Stream * stream,blas::UpperLower uplo,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1115 bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1116                           uint64 k, float alpha, const DeviceMemory<float> &a,
1117                           int lda, const DeviceMemory<float> &x, int incx,
1118                           float beta, DeviceMemory<float> *y, int incy) {
1119   LOG(ERROR) << "rocBLAS does not currently support the SBMV operation "
1120              << "for the \"complex<float>\" dataype";
1121 
1122   return false;
1123 }
1124 
DoBlasSbmv(Stream * stream,blas::UpperLower uplo,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1125 bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1126                           uint64 k, double alpha, const DeviceMemory<double> &a,
1127                           int lda, const DeviceMemory<double> &x, int incx,
1128                           double beta, DeviceMemory<double> *y, int incy) {
1129   LOG(ERROR) << "rocBLAS does not currently support the SBMV operation "
1130              << "for the \"complex<double>\" dataype";
1131   return false;
1132 }
1133 
DoBlasSpmv(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & ap,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1134 bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1135                           float alpha, const DeviceMemory<float> &ap,
1136                           const DeviceMemory<float> &x, int incx, float beta,
1137                           DeviceMemory<float> *y, int incy) {
1138   LOG(ERROR) << "rocBLAS does not currently support the SPMV operation "
1139              << "for the \"float\" dataype";
1140   return false;
1141 }
1142 
DoBlasSpmv(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & ap,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1143 bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
1144                           double alpha, const DeviceMemory<double> &ap,
1145                           const DeviceMemory<double> &x, int incx, double beta,
1146                           DeviceMemory<double> *y, int incy) {
1147   LOG(ERROR) << "rocBLAS does not currently support the SPMV operation "
1148              << "for the \"double\" dataype";
1149   return false;
1150 }
1151 
DoBlasSpr(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * ap)1152 bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1153                          float alpha, const DeviceMemory<float> &x, int incx,
1154                          DeviceMemory<float> *ap) {
1155   LOG(ERROR) << "rocBLAS does not currently support the SPR operation "
1156              << "for the \"float\" dataype";
1157   return false;
1158 }
1159 
DoBlasSpr(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * ap)1160 bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
1161                          double alpha, const DeviceMemory<double> &x, int incx,
1162                          DeviceMemory<double> *ap) {
1163   LOG(ERROR) << "rocBLAS does not currently support the SPR operation "
1164              << "for the \"double\" dataype";
1165   return false;
1166 }
1167 
DoBlasSpr2(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * ap)1168 bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1169                           float alpha, const DeviceMemory<float> &x, int incx,
1170                           const DeviceMemory<float> &y, int incy,
1171                           DeviceMemory<float> *ap) {
1172   LOG(ERROR) << "rocBLAS does not currently support the SPR2 operation "
1173              << "for the \"float\" dataype";
1174   return false;
1175 }
1176 
DoBlasSpr2(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * ap)1177 bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1178                           double alpha, const DeviceMemory<double> &x, int incx,
1179                           const DeviceMemory<double> &y, int incy,
1180                           DeviceMemory<double> *ap) {
1181   LOG(ERROR) << "rocBLAS does not currently support the SPR2 operation "
1182              << "for the \"double\" dataype";
1183   return false;
1184 }
1185 
DoBlasSymv(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1186 bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
1187                           float alpha, const DeviceMemory<float> &a, int lda,
1188                           const DeviceMemory<float> &x, int incx, float beta,
1189                           DeviceMemory<float> *y, int incy) {
1190   LOG(ERROR) << "rocBLAS does not currently support the SYMV operation "
1191              << "for the \"float\" dataype";
1192   return false;
1193 }
1194 
DoBlasSymv(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1195 bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
1196                           double alpha, const DeviceMemory<double> &a, int lda,
1197                           const DeviceMemory<double> &x, int incx, double beta,
1198                           DeviceMemory<double> *y, int incy) {
1199   LOG(ERROR) << "rocBLAS does not currently support the SYMV operation "
1200              << "for the \"double\" dataype";
1201   return false;
1202 }
1203 
DoBlasSyr(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * a,int lda)1204 bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
1205                          float alpha, const DeviceMemory<float> &x, int incx,
1206                          DeviceMemory<float> *a, int lda) {
1207   return DoBlasInternal(wrap::rocblas_ssyr, stream,
1208                         true /* = pointer_mode_host */,
1209                         ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1210                         GpuMemoryMutable(a), lda);
1211 }
1212 
DoBlasSyr(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * a,int lda)1213 bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
1214                          double alpha, const DeviceMemory<double> &x, int incx,
1215                          DeviceMemory<double> *a, int lda) {
1216   return DoBlasInternal(wrap::rocblas_dsyr, stream,
1217                         true /* = pointer_mode_host */,
1218                         ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
1219                         GpuMemoryMutable(a), lda);
1220 }
1221 
DoBlasSyr2(Stream * stream,blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)1222 bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1223                           float alpha, const DeviceMemory<float> &x, int incx,
1224                           const DeviceMemory<float> &y, int incy,
1225                           DeviceMemory<float> *a, int lda) {
1226   LOG(ERROR) << "rocBLAS does not currently support the SYR2 operation "
1227              << "for the \"float\" dataype";
1228   return false;
1229 }
1230 
DoBlasSyr2(Stream * stream,blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)1231 bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
1232                           double alpha, const DeviceMemory<double> &x, int incx,
1233                           const DeviceMemory<double> &y, int incy,
1234                           DeviceMemory<double> *a, int lda) {
1235   LOG(ERROR) << "rocBLAS does not currently support the SYR2 operation "
1236              << "for the \"double\" dataype";
1237   return false;
1238 }
1239 
DoBlasTbmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)1240 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1241                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1242                           uint64 k, const DeviceMemory<float> &a, int lda,
1243                           DeviceMemory<float> *x, int incx) {
1244   LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
1245              << "for the \"float\" dataype";
1246   return false;
1247 }
1248 
DoBlasTbmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)1249 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1250                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1251                           uint64 k, const DeviceMemory<double> &a, int lda,
1252                           DeviceMemory<double> *x, int incx) {
1253   LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
1254              << "for the \"double\" dataype";
1255   return false;
1256 }
1257 
DoBlasTbmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)1258 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1259                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1260                           uint64 k, const DeviceMemory<std::complex<float>> &a,
1261                           int lda, DeviceMemory<std::complex<float>> *x,
1262                           int incx) {
1263   LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
1264              << "for the \"complex<float>\" dataype";
1265   return false;
1266 }
1267 
DoBlasTbmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)1268 bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
1269                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1270                           uint64 k, const DeviceMemory<std::complex<double>> &a,
1271                           int lda, DeviceMemory<std::complex<double>> *x,
1272                           int incx) {
1273   LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
1274              << "for the \"complex<double>\" dataype";
1275   return false;
1276 }
1277 
DoBlasTbsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)1278 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1279                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1280                           uint64 k, const DeviceMemory<float> &a, int lda,
1281                           DeviceMemory<float> *x, int incx) {
1282   LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
1283              << "for the \"float\" dataype";
1284   return false;
1285 }
1286 
DoBlasTbsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)1287 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1288                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1289                           uint64 k, const DeviceMemory<double> &a, int lda,
1290                           DeviceMemory<double> *x, int incx) {
1291   LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
1292              << "for the \"double\" dataype";
1293   return false;
1294 }
1295 
DoBlasTbsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)1296 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1297                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1298                           uint64 k, const DeviceMemory<std::complex<float>> &a,
1299                           int lda, DeviceMemory<std::complex<float>> *x,
1300                           int incx) {
1301   LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
1302              << "for the \"complex<float>\" dataype";
1303   return false;
1304 }
1305 
DoBlasTbsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)1306 bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
1307                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1308                           uint64 k, const DeviceMemory<std::complex<double>> &a,
1309                           int lda, DeviceMemory<std::complex<double>> *x,
1310                           int incx) {
1311   LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
1312              << "for the \"complex<double>\" dataype";
1313   return false;
1314 }
1315 
DoBlasTpmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)1316 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1317                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1318                           const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1319                           int incx) {
1320   LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
1321              << "for the \"float\" dataype";
1322   return false;
1323 }
1324 
DoBlasTpmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)1325 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1326                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1327                           const DeviceMemory<double> &ap,
1328                           DeviceMemory<double> *x, int incx) {
1329   LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
1330              << "for the \"double\" dataype";
1331   return false;
1332 }
1333 
DoBlasTpmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)1334 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1335                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1336                           const DeviceMemory<std::complex<float>> &ap,
1337                           DeviceMemory<std::complex<float>> *x, int incx) {
1338   LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
1339              << "for the \"complex<float>\" dataype";
1340   return false;
1341 }
1342 
DoBlasTpmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)1343 bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
1344                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1345                           const DeviceMemory<std::complex<double>> &ap,
1346                           DeviceMemory<std::complex<double>> *x, int incx) {
1347   LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
1348              << "for the \"complex<double>\" dataype";
1349   return false;
1350 }
1351 
DoBlasTpsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)1352 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1353                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1354                           const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1355                           int incx) {
1356   LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
1357              << "for the \"float\" dataype";
1358   return false;
1359 }
1360 
DoBlasTpsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)1361 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1362                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1363                           const DeviceMemory<double> &ap,
1364                           DeviceMemory<double> *x, int incx) {
1365   LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
1366              << "for the \"double\" dataype";
1367   return false;
1368 }
1369 
DoBlasTpsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)1370 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1371                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1372                           const DeviceMemory<std::complex<float>> &ap,
1373                           DeviceMemory<std::complex<float>> *x, int incx) {
1374   LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
1375              << "for the \"complex<float>\" dataype";
1376   return false;
1377 }
1378 
DoBlasTpsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)1379 bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
1380                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1381                           const DeviceMemory<std::complex<double>> &ap,
1382                           DeviceMemory<std::complex<double>> *x, int incx) {
1383   LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
1384              << "for the \"complex<double>\" dataype";
1385   return false;
1386 }
1387 
DoBlasTrmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)1388 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1389                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1390                           const DeviceMemory<float> &a, int lda,
1391                           DeviceMemory<float> *x, int incx) {
1392   LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
1393              << "for the \"float\" dataype";
1394   return false;
1395 }
1396 
DoBlasTrmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)1397 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1398                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1399                           const DeviceMemory<double> &a, int lda,
1400                           DeviceMemory<double> *x, int incx) {
1401   LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
1402              << "for the \"double\" dataype";
1403   return false;
1404 }
1405 
DoBlasTrmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)1406 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1407                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1408                           const DeviceMemory<std::complex<float>> &a, int lda,
1409                           DeviceMemory<std::complex<float>> *x, int incx) {
1410   LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
1411              << "for the \"complex<float>\" dataype";
1412   return false;
1413 }
1414 
DoBlasTrmv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)1415 bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
1416                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1417                           const DeviceMemory<std::complex<double>> &a, int lda,
1418                           DeviceMemory<std::complex<double>> *x, int incx) {
1419   LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
1420              << "for the \"complex<double>\" dataype";
1421   return false;
1422 }
1423 
DoBlasTrsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)1424 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1425                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1426                           const DeviceMemory<float> &a, int lda,
1427                           DeviceMemory<float> *x, int incx) {
1428   LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
1429              << "for the \"float\" dataype";
1430   return false;
1431 }
1432 
DoBlasTrsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)1433 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1434                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1435                           const DeviceMemory<double> &a, int lda,
1436                           DeviceMemory<double> *x, int incx) {
1437   LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
1438              << "for the \"double\" dataype";
1439   return false;
1440 }
1441 
DoBlasTrsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)1442 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1443                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1444                           const DeviceMemory<std::complex<float>> &a, int lda,
1445                           DeviceMemory<std::complex<float>> *x, int incx) {
1446   LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
1447              << "for the \"complex<float>\" dataype";
1448   return false;
1449 }
1450 
DoBlasTrsv(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)1451 bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
1452                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
1453                           const DeviceMemory<std::complex<double>> &a, int lda,
1454                           DeviceMemory<std::complex<double>> *x, int incx) {
1455   LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
1456              << "for the \"complex<double>\" dataype";
1457   return false;
1458 }
1459 
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc)1460 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1461                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1462                           float alpha, const DeviceMemory<Eigen::half> &a,
1463                           int lda, const DeviceMemory<Eigen::half> &b, int ldb,
1464                           float beta, DeviceMemory<Eigen::half> *c, int ldc) {
1465   VLOG(1) << port::Printf(
1466       "doing rocBLAS SGEMM: at=%d bt=%d m=%llu n=%llu "
1467       "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
1468       "c=%p ldc=%d",
1469       static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
1470       a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
1471   if (transa == blas::Transpose::kNoTranspose) {
1472     if (lda < static_cast<int64>(m)) {
1473       LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
1474                       "precondition violation";
1475     }
1476   } else {
1477     if (lda < static_cast<int64>(k)) {
1478       LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
1479                    << ") (transpose case); precondition violation";
1480     }
1481   }
1482   if (transb == blas::Transpose::kNoTranspose) {
1483     if (ldb < static_cast<int64>(k)) {
1484       LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
1485                    << ") (no transpose case); precondition violation";
1486     }
1487   } else {
1488     if (ldb < static_cast<int64>(n)) {
1489       LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
1490                       "precondition violation";
1491     }
1492   }
1493   const Eigen::half alpha_half(alpha);
1494   const Eigen::half beta_half(beta);
1495   return DoBlasInternal(
1496       wrap::rocblas_hgemm, stream, true /* = pointer_mode_host */,
1497       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
1498       reinterpret_cast<const rocblas_half *>(&alpha_half),
1499       reinterpret_cast<const rocblas_half *>(GpuMemory(a)), lda,
1500       reinterpret_cast<const rocblas_half *>(GpuMemory(b)), ldb,
1501       reinterpret_cast<const rocblas_half *>(&beta_half),
1502       reinterpret_cast<rocblas_half *>(GpuMemoryMutable(c)), ldc);
1503 }
1504 
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)1505 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1506                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1507                           float alpha, const DeviceMemory<float> &a, int lda,
1508                           const DeviceMemory<float> &b, int ldb, float beta,
1509                           DeviceMemory<float> *c, int ldc) {
1510   VLOG(1) << port::Printf(
1511       "doing rocBLAS SGEMM: at=%d bt=%d m=%llu n=%llu "
1512       "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
1513       "c=%p ldc=%d",
1514       static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
1515       a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
1516   if (transa == blas::Transpose::kNoTranspose) {
1517     if (lda < static_cast<int64>(m)) {
1518       LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
1519                       "precondition violation";
1520     }
1521   } else {
1522     if (lda < static_cast<int64>(k)) {
1523       LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
1524                    << ") (transpose case); precondition violation";
1525     }
1526   }
1527   if (transb == blas::Transpose::kNoTranspose) {
1528     if (ldb < static_cast<int64>(k)) {
1529       LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
1530                    << ") (no transpose case); precondition violation";
1531     }
1532   } else {
1533     if (ldb < static_cast<int64>(n)) {
1534       LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
1535                       "precondition violation";
1536     }
1537   }
1538   return DoBlasInternal(
1539       wrap::rocblas_sgemm, stream, true /* = pointer_mode_host */,
1540       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha,
1541       GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
1542 }
1543 
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)1544 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1545                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1546                           double alpha, const DeviceMemory<double> &a, int lda,
1547                           const DeviceMemory<double> &b, int ldb, double beta,
1548                           DeviceMemory<double> *c, int ldc) {
1549   return DoBlasInternal(
1550       wrap::rocblas_dgemm, stream, true /* = pointer_mode_host */,
1551       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha,
1552       GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
1553 }
1554 
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)1555 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1556                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1557                           std::complex<float> alpha,
1558                           const DeviceMemory<std::complex<float>> &a, int lda,
1559                           const DeviceMemory<std::complex<float>> &b, int ldb,
1560                           std::complex<float> beta,
1561                           DeviceMemory<std::complex<float>> *c, int ldc) {
1562   LOG(ERROR) << "rocBLAS does not currently support the GEMM operation "
1563              << "for the \"complex<float>\" dataype";
1564   return false;
1565 }
1566 
DoBlasGemm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)1567 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
1568                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1569                           std::complex<double> alpha,
1570                           const DeviceMemory<std::complex<double>> &a, int lda,
1571                           const DeviceMemory<std::complex<double>> &b, int ldb,
1572                           std::complex<double> beta,
1573                           DeviceMemory<std::complex<double>> *c, int ldc) {
1574   LOG(ERROR) << "rocBLAS does not currently support the GEMM operation "
1575              << "for the \"complex<double>\" dataype";
1576   return false;
1577 }
1578 
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy,blas::ProfileResult * output_profile_result)1579 bool ROCMBlas::DoBlasGemvWithProfiling(
1580     Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
1581     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
1582     int incx, float beta, DeviceMemory<float> *y, int incy,
1583     blas::ProfileResult *output_profile_result) {
1584   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1585                                      incx, beta, y, incy,
1586                                      output_profile_result);
1587 }
1588 
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy,blas::ProfileResult * output_profile_result)1589 bool ROCMBlas::DoBlasGemvWithProfiling(
1590     Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
1591     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
1592     int incx, double beta, DeviceMemory<double> *y, int incy,
1593     blas::ProfileResult *output_profile_result) {
1594   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1595                                      incx, beta, y, incy,
1596                                      output_profile_result);
1597 }
1598 
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy,blas::ProfileResult * output_profile_result)1599 bool ROCMBlas::DoBlasGemvWithProfiling(
1600     Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
1601     std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
1602     int lda, const DeviceMemory<std::complex<float>> &x, int incx,
1603     std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
1604     blas::ProfileResult *output_profile_result) {
1605   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1606                                      incx, beta, y, incy,
1607                                      output_profile_result);
1608 }
1609 
DoBlasGemvWithProfiling(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy,blas::ProfileResult * output_profile_result)1610 bool ROCMBlas::DoBlasGemvWithProfiling(
1611     Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
1612     std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
1613     int lda, const DeviceMemory<std::complex<double>> &x, int incx,
1614     std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
1615     blas::ProfileResult *output_profile_result) {
1616   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
1617                                      incx, beta, y, incy,
1618                                      output_profile_result);
1619 }
1620 
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ProfileResult * output_profile_result)1621 bool ROCMBlas::DoBlasGemmWithProfiling(
1622     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1623     uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
1624     int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
1625     DeviceMemory<Eigen::half> *c, int ldc,
1626     blas::ProfileResult *output_profile_result) {
1627   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1628                                      lda, b, ldb, beta, c, ldc,
1629                                      output_profile_result);
1630 }
1631 
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ProfileResult * output_profile_result)1632 bool ROCMBlas::DoBlasGemmWithProfiling(
1633     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1634     uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
1635     const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
1636     int ldc, blas::ProfileResult *output_profile_result) {
1637   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1638                                      lda, b, ldb, beta, c, ldc,
1639                                      output_profile_result);
1640 }
1641 
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ProfileResult * output_profile_result)1642 bool ROCMBlas::DoBlasGemmWithProfiling(
1643     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1644     uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
1645     const DeviceMemory<double> &b, int ldb, double beta,
1646     DeviceMemory<double> *c, int ldc,
1647     blas::ProfileResult *output_profile_result) {
1648   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1649                                      lda, b, ldb, beta, c, ldc,
1650                                      output_profile_result);
1651 }
1652 
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ProfileResult * output_profile_result)1653 bool ROCMBlas::DoBlasGemmWithProfiling(
1654     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1655     uint64 n, uint64 k, std::complex<float> alpha,
1656     const DeviceMemory<std::complex<float>> &a, int lda,
1657     const DeviceMemory<std::complex<float>> &b, int ldb,
1658     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1659     blas::ProfileResult *output_profile_result) {
1660   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1661                                      lda, b, ldb, beta, c, ldc,
1662                                      output_profile_result);
1663 }
1664 
DoBlasGemmWithProfiling(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ProfileResult * output_profile_result)1665 bool ROCMBlas::DoBlasGemmWithProfiling(
1666     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1667     uint64 n, uint64 k, std::complex<double> alpha,
1668     const DeviceMemory<std::complex<double>> &a, int lda,
1669     const DeviceMemory<std::complex<double>> &b, int ldb,
1670     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1671     blas::ProfileResult *output_profile_result) {
1672   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
1673                                      lda, b, ldb, beta, c, ldc,
1674                                      output_profile_result);
1675 }
1676 
1677 template <typename T>
DoBlasGemvWithProfilingImpl(Stream * stream,blas::Transpose trans,uint64 m,uint64 n,const T & alpha,const DeviceMemory<T> & a,int lda,const DeviceMemory<T> & x,int incx,const T & beta,DeviceMemory<T> * y,int incy,blas::ProfileResult * output_profile_result)1678 bool ROCMBlas::DoBlasGemvWithProfilingImpl(
1679     Stream *stream, blas::Transpose trans, uint64 m, uint64 n, const T &alpha,
1680     const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
1681     const T &beta, DeviceMemory<T> *y, int incy,
1682     blas::ProfileResult *output_profile_result) {
1683   // ROCM TODO: properly implement the interface
1684   return false;
1685 }
1686 
1687 template <typename T, typename ParamType>
DoBlasGemmWithProfilingImpl(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const ParamType & alpha,const DeviceMemory<T> & a,int lda,const DeviceMemory<T> & b,int ldb,const ParamType & beta,DeviceMemory<T> * c,int ldc,blas::ProfileResult * output_profile_result)1688 bool ROCMBlas::DoBlasGemmWithProfilingImpl(
1689     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1690     uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
1691     int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
1692     DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
1693   // ROCM TODO: properly implement the interface
1694   return false;
1695 }
1696 
1697 template <typename InT, typename OutT, typename CompT>
DoBlasGemmWithAlgorithmImpl(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const CompT & alpha,const DeviceMemory<InT> & a,int lda,const DeviceMemory<InT> & b,int ldb,const CompT & beta,DeviceMemory<OutT> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1698 bool ROCMBlas::DoBlasGemmWithAlgorithmImpl(
1699     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1700     uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a, int lda,
1701     const DeviceMemory<InT> &b, int ldb, const CompT &beta,
1702     DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
1703     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1704   // ROCM TODO: properly implement the interface
1705   return false;
1706 }
1707 
GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> * out_algorithms)1708 bool ROCMBlas::GetBlasGemmAlgorithms(
1709     std::vector<blas::AlgorithmType> *out_algorithms) {
1710   // ROCM TODO: properly implement the interface
1711   return true;
1712 }
1713 
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<int> & alpha,const DeviceMemory<int8> & a,int lda,const DeviceMemory<int8> & b,int ldb,const HostOrDeviceScalar<int> & beta,DeviceMemory<int32> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1714 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1715     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1716     uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
1717     const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b, int ldb,
1718     const HostOrDeviceScalar<int> &beta, DeviceMemory<int32> *c, int ldc,
1719     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1720     blas::ProfileResult *output_profile_result) {
1721   LOG(ERROR)
1722       << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1723       << "for the \"int8\" dataype";
1724   return false;
1725 }
1726 
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<Eigen::half> & alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,const HostOrDeviceScalar<Eigen::half> & beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1727 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1728     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1729     uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
1730     const DeviceMemory<Eigen::half> &a, int lda,
1731     const DeviceMemory<Eigen::half> &b, int ldb,
1732     const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
1733     int ldc, blas::ComputationType computation_type,
1734     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1735   LOG(ERROR)
1736       << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1737       << "for the \"half\" dataype";
1738   return false;
1739 }
1740 
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<float> & alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,const HostOrDeviceScalar<float> & beta,DeviceMemory<float> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1741 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1742     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1743     uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
1744     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
1745     int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
1746     int ldc, blas::ComputationType computation_type,
1747     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1748   LOG(ERROR)
1749       << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1750       << "for the \"float\" dataype";
1751   return false;
1752 }
1753 
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<double> & alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,const HostOrDeviceScalar<double> & beta,DeviceMemory<double> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1754 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1755     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1756     uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
1757     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
1758     int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
1759     int ldc, blas::ComputationType computation_type,
1760     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1761   LOG(ERROR)
1762       << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1763       << "for the \"double\" dataype";
1764   return false;
1765 }
1766 
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<std::complex<float>> & alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,const HostOrDeviceScalar<std::complex<float>> & beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1767 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1768     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1769     uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
1770     const DeviceMemory<std::complex<float>> &a, int lda,
1771     const DeviceMemory<std::complex<float>> &b, int ldb,
1772     const HostOrDeviceScalar<std::complex<float>> &beta,
1773     DeviceMemory<std::complex<float>> *c, int ldc,
1774     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1775     blas::ProfileResult *output_profile_result) {
1776   LOG(ERROR)
1777       << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1778       << "for the \"complex<float>\" dataype";
1779   return false;
1780 }
1781 
DoBlasGemmWithAlgorithm(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<std::complex<double>> & alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,const HostOrDeviceScalar<std::complex<double>> & beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)1782 bool ROCMBlas::DoBlasGemmWithAlgorithm(
1783     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1784     uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
1785     const DeviceMemory<std::complex<double>> &a, int lda,
1786     const DeviceMemory<std::complex<double>> &b, int ldb,
1787     const HostOrDeviceScalar<std::complex<double>> &beta,
1788     DeviceMemory<std::complex<double>> *c, int ldc,
1789     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1790     blas::ProfileResult *output_profile_result) {
1791   LOG(ERROR)
1792       << "rocBLAS does not currently support the GEMMwithAlgorithm operation "
1793       << "for the \"complex<double>\" dataype";
1794   return false;
1795 }
1796 
// Type-mapping trait: the primary template is the identity, i.e.
// EigenHalfToRocBlasHalf<T>::type is T itself unless a specialization says
// otherwise.
template <typename T>
struct EigenHalfToRocBlasHalf { using type = T; };
1801 
1802 template <>
1803 struct EigenHalfToRocBlasHalf<Eigen::half> {
1804   using type = rocblas_half;
1805 };
1806 
1807 template <typename T, typename FuncT>
DoBlasGemmBatchedInternal(FuncT rocblas_func,Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,T alpha,const port::ArraySlice<DeviceMemory<T> * > & a_ptrs_to_wrappers,int lda,const port::ArraySlice<DeviceMemory<T> * > & b_ptrs_to_wrappers,int ldb,T beta,const port::ArraySlice<DeviceMemory<T> * > & c_ptrs_to_wrappers,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1808 port::Status ROCMBlas::DoBlasGemmBatchedInternal(
1809     FuncT rocblas_func, Stream *stream, blas::Transpose transa,
1810     blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
1811     const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
1812     const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
1813     T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
1814     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
1815   // MAPPED_T will be same as T for all types except Eigen::Half
1816   // for T = Eigen::half, MAPPED_T = rocblas_half
1817   using MAPPED_T = typename EigenHalfToRocBlasHalf<T>::type;
1818 
1819   // Alocate local vectors to hold device pointers to matrices
1820   std::vector<MAPPED_T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
1821   for (int i = 0; i < batch_count; ++i) {
1822     // static_cast does work when converting Eigen::half* to rocblas_half*,
1823     // hence the use od reinterpret_cast
1824     a_raw_ptrs.push_back(
1825         reinterpret_cast<MAPPED_T *>(a_ptrs_to_wrappers[i]->opaque()));
1826     b_raw_ptrs.push_back(
1827         reinterpret_cast<MAPPED_T *>(b_ptrs_to_wrappers[i]->opaque()));
1828     c_raw_ptrs.push_back(
1829         reinterpret_cast<MAPPED_T *>(c_ptrs_to_wrappers[i]->opaque()));
1830   }
1831 
1832   //  batch_count <= 1 is base case, no definable matrix stride, set it same as
1833   //  ld*
1834   long long bsa = lda;
1835   long long bsb = ldb;
1836   long long bsc = ldc;
1837   bool bsa_is_constant = true;
1838   bool bsb_is_constant = true;
1839   bool bsc_is_constant = true;
1840 
1841   if (batch_count > 1) {
1842     // Remember first stride; if any other stride is different that this one,
1843     // KABLAM
1844     bsa = a_raw_ptrs[1] - a_raw_ptrs[0];
1845     bsb = b_raw_ptrs[1] - b_raw_ptrs[0];
1846     bsc = c_raw_ptrs[1] - c_raw_ptrs[0];
1847 
1848     //  Loop to verify that batched strides are constant
1849     //  All the test cases from batch_matmul_op_test.py seem to satisfy this
1850     //  requirement of a constant stride.  If this can be proven globally, then
1851     //  this loop check can be safely removed
1852     for (int i = 1; i < batch_count - 1; ++i) {
1853       long long iterative_bsa = a_raw_ptrs[i + 1] - a_raw_ptrs[i];
1854       if (iterative_bsa != bsa) {
1855         bsa_is_constant = false;
1856         break;
1857       }
1858 
1859       long long iterative_bsb = b_raw_ptrs[i + 1] - b_raw_ptrs[i];
1860       if (iterative_bsb != bsb) {
1861         bsb_is_constant = false;
1862         break;
1863       }
1864 
1865       long long iterative_bsc = c_raw_ptrs[i + 1] - c_raw_ptrs[i];
1866       if (iterative_bsc != bsc) {
1867         bsc_is_constant = false;
1868         break;
1869       }
1870     }
1871   }
1872 
1873   assert(!(ldc < m || bsc < ldc * n));
1874 
1875   if (ROCMBlasTranspose(transa) == rocblas_operation_none)
1876     assert(!(lda < m || bsa < lda * k));
1877   else
1878     assert(!(lda < k || bsa < lda * m));
1879 
1880   if (ROCMBlasTranspose(transb) == rocblas_operation_none)
1881     assert(!(ldb < k || bsb < ldb * n));
1882   else
1883     assert(!(ldb < n || bsc < ldc * k));
1884 
1885   MAPPED_T *alpha_ptr = reinterpret_cast<MAPPED_T *>(&alpha);
1886   MAPPED_T *beta_ptr = reinterpret_cast<MAPPED_T *>(&beta);
1887 
1888   if (bsa_is_constant && bsb_is_constant && bsc_is_constant) {
1889     bool ok = DoBlasInternal(
1890         rocblas_func, stream, true /* = pointer_mode_host */,
1891         ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
1892         GpuComplex(alpha_ptr), a_raw_ptrs[0], lda, bsa, b_raw_ptrs[0], ldb, bsb,
1893         GpuComplex(beta_ptr), c_raw_ptrs[0], ldc, bsc, batch_count);
1894 
1895     if (ok) {
1896       return port::Status::OK();
1897     }
1898   }
1899 
1900   return port::Status(port::error::INTERNAL,
1901                       "failed BLAS call, see log for details");
1902 }
1903 
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1904 bool ROCMBlas::DoBlasGemmBatched(
1905     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1906     uint64 n, uint64 k, float alpha,
1907     const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
1908     const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
1909     const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
1910     int batch_count, ScratchAllocator *scratch_allocator) {
1911   const Eigen::half alpha_half(alpha);
1912   const Eigen::half beta_half(beta);
1913 
1914   port::Status status = DoBlasGemmBatchedInternal(
1915       wrap::rocblas_hgemm_strided_batched, stream, transa, transb, m, n, k,
1916       alpha_half, a, lda, b, ldb, beta_half, c, ldc, batch_count,
1917       scratch_allocator);
1918   if (!status.ok()) {
1919     LOG(ERROR) << status;
1920   }
1921 
1922   return status.ok();
1923 }
1924 
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a_array,int lda,const port::ArraySlice<DeviceMemory<float> * > & b_array,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1925 bool ROCMBlas::DoBlasGemmBatched(
1926     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1927     uint64 n, uint64 k, float alpha,
1928     const port::ArraySlice<DeviceMemory<float> *> &a_array, int lda,
1929     const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta,
1930     const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc,
1931     int batch_count, ScratchAllocator *scratch_allocator) {
1932   port::Status status = DoBlasGemmBatchedInternal(
1933       wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k,
1934       alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
1935       scratch_allocator);
1936   if (!status.ok()) {
1937     LOG(ERROR) << status;
1938   }
1939   return status.ok();
1940 }
1941 
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a_array,int lda,const port::ArraySlice<DeviceMemory<double> * > & b_array,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1942 bool ROCMBlas::DoBlasGemmBatched(
1943     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1944     uint64 n, uint64 k, double alpha,
1945     const port::ArraySlice<DeviceMemory<double> *> &a_array, int lda,
1946     const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb,
1947     double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array,
1948     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
1949   port::Status status = DoBlasGemmBatchedInternal(
1950       wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k,
1951       alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
1952       scratch_allocator);
1953   if (!status.ok()) {
1954     LOG(ERROR) << status;
1955   }
1956   return status.ok();
1957 }
1958 
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a_array,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b_array,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1959 bool ROCMBlas::DoBlasGemmBatched(
1960     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1961     uint64 n, uint64 k, std::complex<float> alpha,
1962     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a_array,
1963     int lda,
1964     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b_array,
1965     int ldb, std::complex<float> beta,
1966     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array,
1967     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
1968   LOG(ERROR) << "rocBLAS does not currently support the GEMMBatched operation "
1969              << "for the \"complex<float>\" dataype";
1970   return false;
1971 }
1972 
DoBlasGemmBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a_array,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b_array,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c_array,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1973 bool ROCMBlas::DoBlasGemmBatched(
1974     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1975     uint64 n, uint64 k, std::complex<double> alpha,
1976     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a_array,
1977     int lda,
1978     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b_array,
1979     int ldb, std::complex<double> beta,
1980     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array,
1981     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
1982   LOG(ERROR) << "rocBLAS does not currently support the GEMMBatched operation "
1983              << "for the \"complex<double>\" dataype";
1984   return false;
1985 }
1986 
DoBlasHemm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)1987 bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side,
1988                           blas::UpperLower uplo, uint64 m, uint64 n,
1989                           std::complex<float> alpha,
1990                           const DeviceMemory<std::complex<float>> &a, int lda,
1991                           const DeviceMemory<std::complex<float>> &b, int ldb,
1992                           std::complex<float> beta,
1993                           DeviceMemory<std::complex<float>> *c, int ldc) {
1994   LOG(ERROR) << "rocBLAS does not currently support the HEMM operation "
1995              << "for the \"complex<float>\" dataype";
1996   return false;
1997 }
1998 
DoBlasHemm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)1999 bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side,
2000                           blas::UpperLower uplo, uint64 m, uint64 n,
2001                           std::complex<double> alpha,
2002                           const DeviceMemory<std::complex<double>> &a, int lda,
2003                           const DeviceMemory<std::complex<double>> &b, int ldb,
2004                           std::complex<double> beta,
2005                           DeviceMemory<std::complex<double>> *c, int ldc) {
2006   LOG(ERROR) << "rocBLAS does not currently support the HEMM operation "
2007              << "for the \"complex<double>\" dataype";
2008   return false;
2009 }
2010 
DoBlasHerk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<std::complex<float>> & a,int lda,float beta,DeviceMemory<std::complex<float>> * c,int ldc)2011 bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
2012                           blas::Transpose trans, uint64 n, uint64 k,
2013                           float alpha,
2014                           const DeviceMemory<std::complex<float>> &a, int lda,
2015                           float beta, DeviceMemory<std::complex<float>> *c,
2016                           int ldc) {
2017   LOG(ERROR) << "rocBLAS does not currently support the HERK operation "
2018              << "for the \"complex<float>\" dataype";
2019   return false;
2020 }
2021 
DoBlasHerk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<std::complex<double>> & a,int lda,double beta,DeviceMemory<std::complex<double>> * c,int ldc)2022 bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
2023                           blas::Transpose trans, uint64 n, uint64 k,
2024                           double alpha,
2025                           const DeviceMemory<std::complex<double>> &a, int lda,
2026                           double beta, DeviceMemory<std::complex<double>> *c,
2027                           int ldc) {
2028   LOG(ERROR) << "rocBLAS does not currently support the HERK operation "
2029              << "for the \"complex<double>\" dataype";
2030   return false;
2031 }
2032 
DoBlasHer2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,float beta,DeviceMemory<std::complex<float>> * c,int ldc)2033 bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
2034                            blas::Transpose trans, uint64 n, uint64 k,
2035                            std::complex<float> alpha,
2036                            const DeviceMemory<std::complex<float>> &a, int lda,
2037                            const DeviceMemory<std::complex<float>> &b, int ldb,
2038                            float beta, DeviceMemory<std::complex<float>> *c,
2039                            int ldc) {
2040   LOG(ERROR) << "rocBLAS does not currently support the HER2K operation "
2041              << "for the \"complex<float>\" dataype";
2042   return false;
2043 }
2044 
DoBlasHer2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,double beta,DeviceMemory<std::complex<double>> * c,int ldc)2045 bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
2046                            blas::Transpose trans, uint64 n, uint64 k,
2047                            std::complex<double> alpha,
2048                            const DeviceMemory<std::complex<double>> &a, int lda,
2049                            const DeviceMemory<std::complex<double>> &b, int ldb,
2050                            double beta, DeviceMemory<std::complex<double>> *c,
2051                            int ldc) {
2052   LOG(ERROR) << "rocBLAS does not currently support the HER2K operation "
2053              << "for the \"complex<double>\" dataype";
2054   return false;
2055 }
2056 
DoBlasSymm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)2057 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
2058                           blas::UpperLower uplo, uint64 m, uint64 n,
2059                           float alpha, const DeviceMemory<float> &a, int lda,
2060                           const DeviceMemory<float> &b, int ldb, float beta,
2061                           DeviceMemory<float> *c, int ldc) {
2062   LOG(ERROR) << "rocBLAS does not currently support the SYMM operation "
2063              << "for the \"float\" dataype";
2064   return false;
2065 }
2066 
DoBlasSymm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)2067 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
2068                           blas::UpperLower uplo, uint64 m, uint64 n,
2069                           double alpha, const DeviceMemory<double> &a, int lda,
2070                           const DeviceMemory<double> &b, int ldb, double beta,
2071                           DeviceMemory<double> *c, int ldc) {
2072   LOG(ERROR) << "rocBLAS does not currently support the SYMM operation "
2073              << "for the \"double\" dataype";
2074   return false;
2075 }
2076 
DoBlasSymm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)2077 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
2078                           blas::UpperLower uplo, uint64 m, uint64 n,
2079                           std::complex<float> alpha,
2080                           const DeviceMemory<std::complex<float>> &a, int lda,
2081                           const DeviceMemory<std::complex<float>> &b, int ldb,
2082                           std::complex<float> beta,
2083                           DeviceMemory<std::complex<float>> *c, int ldc) {
2084   LOG(ERROR) << "rocBLAS does not currently support the SYMM operation "
2085              << "for the \"complex<float>\" dataype";
2086   return false;
2087 }
2088 
DoBlasSymm(Stream * stream,blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)2089 bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
2090                           blas::UpperLower uplo, uint64 m, uint64 n,
2091                           std::complex<double> alpha,
2092                           const DeviceMemory<std::complex<double>> &a, int lda,
2093                           const DeviceMemory<std::complex<double>> &b, int ldb,
2094                           std::complex<double> beta,
2095                           DeviceMemory<std::complex<double>> *c, int ldc) {
2096   LOG(ERROR) << "rocBLAS does not currently support the SYMM operation "
2097              << "for the \"complex<double>\" dataype";
2098   return false;
2099 }
2100 
DoBlasSyrk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,float beta,DeviceMemory<float> * c,int ldc)2101 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2102                           blas::Transpose trans, uint64 n, uint64 k,
2103                           float alpha, const DeviceMemory<float> &a, int lda,
2104                           float beta, DeviceMemory<float> *c, int ldc) {
2105   LOG(ERROR) << "rocBLAS does not currently support the SYRK operation "
2106              << "for the \"float\" dataype";
2107   return false;
2108 }
2109 
DoBlasSyrk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,double beta,DeviceMemory<double> * c,int ldc)2110 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2111                           blas::Transpose trans, uint64 n, uint64 k,
2112                           double alpha, const DeviceMemory<double> &a, int lda,
2113                           double beta, DeviceMemory<double> *c, int ldc) {
2114   LOG(ERROR) << "rocBLAS does not currently support the SYRK operation "
2115              << "for the \"double\" dataype";
2116   return false;
2117 }
2118 
DoBlasSyrk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)2119 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2120                           blas::Transpose trans, uint64 n, uint64 k,
2121                           std::complex<float> alpha,
2122                           const DeviceMemory<std::complex<float>> &a, int lda,
2123                           std::complex<float> beta,
2124                           DeviceMemory<std::complex<float>> *c, int ldc) {
2125   LOG(ERROR) << "rocBLAS does not currently support the SYRK operation "
2126              << "for the \"complex<float>\" dataype";
2127   return false;
2128 }
2129 
DoBlasSyrk(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)2130 bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
2131                           blas::Transpose trans, uint64 n, uint64 k,
2132                           std::complex<double> alpha,
2133                           const DeviceMemory<std::complex<double>> &a, int lda,
2134                           std::complex<double> beta,
2135                           DeviceMemory<std::complex<double>> *c, int ldc) {
2136   LOG(ERROR) << "rocBLAS does not currently support the SYRK operation "
2137              << "for the \"complex<double>\" dataype";
2138   return false;
2139 }
2140 
DoBlasSyr2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)2141 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2142                            blas::Transpose trans, uint64 n, uint64 k,
2143                            float alpha, const DeviceMemory<float> &a, int lda,
2144                            const DeviceMemory<float> &b, int ldb, float beta,
2145                            DeviceMemory<float> *c, int ldc) {
2146   LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation "
2147              << "for the \"float\" dataype";
2148   return false;
2149 }
2150 
DoBlasSyr2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)2151 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2152                            blas::Transpose trans, uint64 n, uint64 k,
2153                            double alpha, const DeviceMemory<double> &a, int lda,
2154                            const DeviceMemory<double> &b, int ldb, double beta,
2155                            DeviceMemory<double> *c, int ldc) {
2156   LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation "
2157              << "for the \"double\" dataype";
2158   return false;
2159 }
2160 
DoBlasSyr2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)2161 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2162                            blas::Transpose trans, uint64 n, uint64 k,
2163                            std::complex<float> alpha,
2164                            const DeviceMemory<std::complex<float>> &a, int lda,
2165                            const DeviceMemory<std::complex<float>> &b, int ldb,
2166                            std::complex<float> beta,
2167                            DeviceMemory<std::complex<float>> *c, int ldc) {
2168   LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation "
2169              << "for the \"complex<float>\" dataype";
2170   return false;
2171 }
2172 
DoBlasSyr2k(Stream * stream,blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)2173 bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
2174                            blas::Transpose trans, uint64 n, uint64 k,
2175                            std::complex<double> alpha,
2176                            const DeviceMemory<std::complex<double>> &a, int lda,
2177                            const DeviceMemory<std::complex<double>> &b, int ldb,
2178                            std::complex<double> beta,
2179                            DeviceMemory<std::complex<double>> *c, int ldc) {
2180   LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation "
2181              << "for the \"complex<double>\" dataype";
2182   return false;
2183 }
2184 
DoBlasTrmm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)2185 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
2186                           blas::UpperLower uplo, blas::Transpose transa,
2187                           blas::Diagonal diag, uint64 m, uint64 n, float alpha,
2188                           const DeviceMemory<float> &a, int lda,
2189                           DeviceMemory<float> *b, int ldb) {
2190   LOG(ERROR) << "rocBLAS does not currently support the TRMM operation "
2191              << "for the \"float\" dataype";
2192   return false;
2193 }
2194 
DoBlasTrmm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)2195 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
2196                           blas::UpperLower uplo, blas::Transpose transa,
2197                           blas::Diagonal diag, uint64 m, uint64 n, double alpha,
2198                           const DeviceMemory<double> &a, int lda,
2199                           DeviceMemory<double> *b, int ldb) {
2200   LOG(ERROR) << "rocBLAS does not currently support the TRMM operation "
2201              << "for the \"double\" dataype";
2202   return false;
2203 }
2204 
DoBlasTrmm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)2205 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
2206                           blas::UpperLower uplo, blas::Transpose transa,
2207                           blas::Diagonal diag, uint64 m, uint64 n,
2208                           std::complex<float> alpha,
2209                           const DeviceMemory<std::complex<float>> &a, int lda,
2210                           DeviceMemory<std::complex<float>> *b, int ldb) {
2211   LOG(ERROR) << "rocBLAS does not currently support the TRMM operation "
2212              << "for the \"complex<float>\" dataype";
2213   return false;
2214 }
2215 
DoBlasTrmm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)2216 bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
2217                           blas::UpperLower uplo, blas::Transpose transa,
2218                           blas::Diagonal diag, uint64 m, uint64 n,
2219                           std::complex<double> alpha,
2220                           const DeviceMemory<std::complex<double>> &a, int lda,
2221                           DeviceMemory<std::complex<double>> *b, int ldb) {
2222   LOG(ERROR) << "rocBLAS does not currently support the TRMM operation "
2223              << "for the \"complex<double>\" dataype";
2224   return false;
2225 }
2226 
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)2227 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
2228                           blas::UpperLower uplo, blas::Transpose transa,
2229                           blas::Diagonal diag, uint64 m, uint64 n, float alpha,
2230                           const DeviceMemory<float> &a, int lda,
2231                           DeviceMemory<float> *b, int ldb) {
2232   return DoBlasInternal(
2233       wrap::rocblas_strsm, stream, true /* = pointer_mode_host */,
2234       ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2235       ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<float *>(GpuMemory(a)),
2236       lda, GpuMemoryMutable(b), ldb);
2237 }
2238 
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)2239 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
2240                           blas::UpperLower uplo, blas::Transpose transa,
2241                           blas::Diagonal diag, uint64 m, uint64 n, double alpha,
2242                           const DeviceMemory<double> &a, int lda,
2243                           DeviceMemory<double> *b, int ldb) {
2244   return DoBlasInternal(
2245       wrap::rocblas_dtrsm, stream, true /* = pointer_mode_host */,
2246       ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
2247       ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<double *>(GpuMemory(a)),
2248       lda, GpuMemoryMutable(b), ldb);
2249 }
2250 
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)2251 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
2252                           blas::UpperLower uplo, blas::Transpose transa,
2253                           blas::Diagonal diag, uint64 m, uint64 n,
2254                           std::complex<float> alpha,
2255                           const DeviceMemory<std::complex<float>> &a, int lda,
2256                           DeviceMemory<std::complex<float>> *b, int ldb) {
2257   LOG(ERROR) << "rocBLAS does not currently support the TRSM operation "
2258              << "for the \"complex<float>\" dataype";
2259   return false;
2260 }
2261 
DoBlasTrsm(Stream * stream,blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)2262 bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
2263                           blas::UpperLower uplo, blas::Transpose transa,
2264                           blas::Diagonal diag, uint64 m, uint64 n,
2265                           std::complex<double> alpha,
2266                           const DeviceMemory<std::complex<double>> &a, int lda,
2267                           DeviceMemory<std::complex<double>> *b, int ldb) {
2268   LOG(ERROR) << "rocBLAS does not currently support the TRSM operation "
2269              << "for the \"complex<double>\" dataype";
2270   return false;
2271 }
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,int64 stride_a,const DeviceMemory<Eigen::half> & b,int ldb,int64 stride_b,float beta,DeviceMemory<Eigen::half> * c,int ldc,int64 stride_c,int batch_count)2272 bool ROCMBlas::DoBlasGemmStridedBatched(
2273     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2274     uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
2275     int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
2276     int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
2277     int64 stride_c, int batch_count) {
2278   LOG(ERROR) << "rocBLAS does not currently support the "
2279                 "DoBlasGemmStridedBatched operation "
2280              << "for the \"Eigen::half\" dataype";
2281   return false;
2282 }
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,int64 stride_a,const DeviceMemory<float> & b,int ldb,int64 stride_b,float beta,DeviceMemory<float> * c,int ldc,int64 stride_c,int batch_count)2283 bool ROCMBlas::DoBlasGemmStridedBatched(
2284     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2285     uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
2286     int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
2287     float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
2288     int batch_count) {
2289   LOG(ERROR) << "rocBLAS does not currently support the "
2290                 "DoBlasGemmStridedBatched operation "
2291              << "for the \"float\" dataype";
2292   return false;
2293 }
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,int64 stride_a,const DeviceMemory<double> & b,int ldb,int64 stride_b,double beta,DeviceMemory<double> * c,int ldc,int64 stride_c,int batch_count)2294 bool ROCMBlas::DoBlasGemmStridedBatched(
2295     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2296     uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
2297     int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
2298     double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
2299     int batch_count) {
2300   LOG(ERROR) << "rocBLAS does not currently support the "
2301                 "DoBlasGemmStridedBatched operation "
2302              << "for the \"double\" dataype";
2303   return false;
2304 }
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,int64 stride_a,const DeviceMemory<std::complex<float>> & b,int ldb,int64 stride_b,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,int64 stride_c,int batch_count)2305 bool ROCMBlas::DoBlasGemmStridedBatched(
2306     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2307     uint64 n, uint64 k, std::complex<float> alpha,
2308     const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
2309     const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
2310     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
2311     int64 stride_c, int batch_count) {
2312   LOG(ERROR) << "rocBLAS does not currently support the "
2313                 "DoBlasGemmStridedBatched operation "
2314              << "for the \"complex<float>\" dataype";
2315   return false;
2316 }
DoBlasGemmStridedBatched(Stream * stream,blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,int64 stride_a,const DeviceMemory<std::complex<double>> & b,int ldb,int64 stride_b,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,int64 stride_c,int batch_count)2317 bool ROCMBlas::DoBlasGemmStridedBatched(
2318     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
2319     uint64 n, uint64 k, std::complex<double> alpha,
2320     const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
2321     const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
2322     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
2323     int64 stride_c, int batch_count) {
2324   LOG(ERROR) << "rocBLAS does not currently support the "
2325                 "DoBlasGemmStridedBatched operation "
2326              << "for the \"complex<double>\" dataype";
2327   return false;
2328 }
2329 }  // namespace gpu
2330 
initialize_rocblas()2331 void initialize_rocblas() {
2332   auto rocBlasAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
2333       rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
2334 
2335   if (!rocBlasAlreadyRegistered) {
2336     port::Status status =
2337         PluginRegistry::Instance()
2338             ->RegisterFactory<PluginRegistry::BlasFactory>(
2339                 rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
2340                 [](internal::StreamExecutorInterface *parent)
2341                     -> blas::BlasSupport * {
2342                   gpu::GpuExecutor *rocm_executor =
2343                       dynamic_cast<gpu::GpuExecutor *>(parent);
2344                   if (rocm_executor == nullptr) {
2345                     LOG(ERROR)
2346                         << "Attempting to initialize an instance of the "
2347                            "rocBLAS "
2348                         << "support library with a non-ROCM StreamExecutor";
2349                     return nullptr;
2350                   }
2351 
2352                   gpu::ROCMBlas *blas = new gpu::ROCMBlas(rocm_executor);
2353                   if (!blas->Init()) {
2354                     // Note: Init() will log a more specific error.
2355                     delete blas;
2356                     return nullptr;
2357                   }
2358                   return blas;
2359                 });
2360 
2361     if (!status.ok()) {
2362       LOG(ERROR) << "Unable to register rocBLAS factory: "
2363                  << status.error_message();
2364     }
2365 
2366     PluginRegistry::Instance()->SetDefaultFactory(
2367         rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
2368   }
2369 }
2370 
2371 }  // namespace stream_executor
2372 
2373 REGISTER_MODULE_INITIALIZER(register_rocblas,
2374                             { stream_executor::initialize_rocblas(); });
2375