1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Exposes the family of BLAS routines as pre-canned high performance calls for
17 // use in conjunction with the StreamExecutor abstraction.
18 //
19 // Note that this interface is optionally supported by platforms; see
20 // StreamExecutor::SupportsBlas() for details.
21 //
22 // This abstraction makes it simple to entrain BLAS operations on GPU data into
23 // a Stream -- users typically will not use this API directly, but will use the
24 // Stream builder methods to entrain these operations "under the hood". For
25 // example:
26 //
27 //  DeviceMemory<float> x = stream_exec->AllocateArray<float>(1024);
28 //  DeviceMemory<float> y = stream_exec->AllocateArray<float>(1024);
29 //  // ... populate x and y ...
30 //  Stream stream{stream_exec};
31 //  stream
32 //    .Init()
33 //    .ThenBlasAxpy(1024, 5.5, x, 1, &y, 1);
34 //  SE_CHECK_OK(stream.BlockHostUntilDone());
35 //
36 // By using stream operations in this manner the user can easily intermix custom
// kernel launches (via Stream::ThenLaunch()) with these pre-canned BLAS
38 // routines.
39 
40 #ifndef TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
41 #define TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
42 
#include <complex>
#include <iosfwd>
#include <limits>
#include <string>
#include <vector>

#include "tensorflow/stream_executor/dnn.h"  // For DataType, ToDataType
#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform/port.h"
50 
// Forward declaration of Eigen::half; this header does not include any Eigen
// headers.
namespace Eigen {
struct half;
}  // namespace Eigen
54 
55 namespace stream_executor {
56 
57 class Stream;
58 class ScratchAllocator;
59 
60 template <typename ElemT>
61 class DeviceMemory;
62 
63 template <typename ElemT>
64 class HostOrDeviceScalar;
65 
66 namespace blas {
67 
// Selects how an input matrix is presented to a BLAS operation: unchanged,
// transposed, or transposed and conjugated.
enum class Transpose {
  kNoTranspose,         // op(A) = A
  kTranspose,           // op(A) = A^T
  kConjugateTranspose,  // op(A) = A^H
};

// Returns a human-readable name for `t`.
std::string TransposeString(Transpose t);
74 
// Selects which triangle (upper or lower) of a symmetric/Hermitian matrix is
// used.
enum class UpperLower {
  kUpper,
  kLower,
};

// Returns a human-readable name for `ul`.
std::string UpperLowerString(UpperLower ul);
81 
// Indicates whether a triangular matrix is unit triangular (kUnit) or not
// (kNonUnit).
enum class Diagonal {
  kUnit,
  kNonUnit,
};

// Returns a human-readable name for `d`.
std::string DiagonalString(Diagonal d);
87 
// Indicates on which side (left or right) of the operation a Hermitian matrix
// appears.
enum class Side {
  kLeft,
  kRight,
};

// Returns a human-readable name for `s`.
std::string SideString(Side s);
94 
// Type in which a blas routine carries out its intermediate arithmetic.
//
// This may differ from the type of the routine's inputs/outputs. For example,
// two int8 matrices can be multiplied while float32 holds the intermediate
// matmul values.
enum class ComputationType {
  kF16,         // 16-bit floating point.
  kF32,         // 32-bit floating point.
  kF64,         // 64-bit floating point.
  kI32,         // 32-bit integer.
  kComplexF32,  // Complex number made of two f32s.
  kComplexF64,  // Complex number made of two f64s.
  // Values below are supported only by BlasLt routines (real and complex).
  // Accumulation happens in float32, but input mantissas are first rounded
  // to fewer bits.
  kTF32AsF32,  // f32 with a reduced (>=10-bit) mantissa.
  kBF16AsF32,  // f32 with a reduced (7-bit) mantissa.
};
113 
// Post-processing applied to the result of a BlasLt matmul. kReLU and kBias
// are distinct bits, and kBiasThenReLU is their bitwise OR.
enum class Epilogue : int {
  kDefault = 1,  // No special postprocessing.
  kReLU = 2,     // Apply ReLU point-wise to the results.
  kBias = 4,     // Add a broadcasted bias vector to the results.
  kBiasThenReLU = kBias | kReLU,  // Apply the bias, then the ReLU transform.
};
120 
121 // Converts a ComputationType to a string.
122 std::string ComputationTypeString(ComputationType ty);
123 
124 std::ostream &operator<<(std::ostream &os, ComputationType ty);
125 
126 using dnn::DataType;
127 using dnn::ToDataType;
128 
// Tells blaslt routines where the alpha and beta scaling factors they receive
// live: in host memory (kHost) or in device memory (kDevice).
enum class PointerMode { kHost, kDevice };
135 
136 // Converts a ComputationType to a string.
137 std::string DataTypeString(DataType ty);
138 
139 std::ostream &operator<<(std::ostream &os, DataType ty);
140 
141 // Opaque identifier for an "algorithm" used by a blas routine.  This functions
142 // as a hint to the blas library.
143 typedef int64 AlgorithmType;
144 constexpr AlgorithmType kDefaultAlgorithm = -1;
145 constexpr AlgorithmType kDefaultBlasGemm = -2;
146 constexpr AlgorithmType kDefaultBlasGemv = -3;
147 constexpr AlgorithmType kNoAlgorithm = -4;
148 
149 // blas uses -1 to represent the default algorithm. This happens to match up
150 // with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast
151 // to convert from AlgorithmType to cublasGemmAlgo_t, and uses a static_assert
152 // to ensure that this assumption does not break.
153 // If another blas implementation uses a different value for the default
154 // algorithm, then it needs to convert kDefaultGemmAlgo to that value
155 // (e.g. via a function called ToWhateverGemmAlgo).
156 constexpr AlgorithmType kDefaultGemmAlgo = -1;
157 
158 // Describes the result of a performance experiment, usually timing the speed of
159 // a particular AlgorithmType.
160 //
161 // If the call we were benchmarking failed (a common occurrence; not all
162 // algorithms are valid for all calls), is_valid() will be false.
163 class ProfileResult {
164  public:
is_valid()165   bool is_valid() const { return is_valid_; }
set_is_valid(bool val)166   void set_is_valid(bool val) { is_valid_ = val; }
algorithm()167   AlgorithmType algorithm() const { return algorithm_; }
set_algorithm(AlgorithmType val)168   void set_algorithm(AlgorithmType val) { algorithm_ = val; }
elapsed_time_in_ms()169   float elapsed_time_in_ms() const { return elapsed_time_in_ms_; }
set_elapsed_time_in_ms(float val)170   void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; }
171 
172  private:
173   bool is_valid_ = false;
174   AlgorithmType algorithm_ = kDefaultAlgorithm;
175   float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
176 };
177 
178 class AlgorithmConfig {
179  public:
AlgorithmConfig()180   AlgorithmConfig() : algorithm_(kDefaultAlgorithm) {}
AlgorithmConfig(AlgorithmType algorithm)181   explicit AlgorithmConfig(AlgorithmType algorithm) : algorithm_(algorithm) {}
algorithm()182   AlgorithmType algorithm() const { return algorithm_; }
set_algorithm(AlgorithmType val)183   void set_algorithm(AlgorithmType val) { algorithm_ = val; }
184   bool operator==(const AlgorithmConfig &other) const {
185     return this->algorithm_ == other.algorithm_;
186   }
187   bool operator!=(const AlgorithmConfig &other) const {
188     return !(*this == other);
189   }
190   std::string ToString() const;
191 
192  private:
193   AlgorithmType algorithm_;
194 };
195 
196 struct IBlasLtMatmulPlan {
197   // Returns the data type of the A and B (input) matrices.
198   virtual DataType ab_type() const = 0;
199   // Returns the data type of the C (input/output) matrix.
200   virtual DataType c_type() const = 0;
~IBlasLtMatmulPlanIBlasLtMatmulPlan201   virtual ~IBlasLtMatmulPlan() {}
202 };
203 
204 struct IBlasLtMatmulAlgorithm {
~IBlasLtMatmulAlgorithmIBlasLtMatmulAlgorithm205   virtual ~IBlasLtMatmulAlgorithm() {}
206   // Returns the index of the algorithm within the list returned by
207   // GetBlasLtMatmulAlgorithms.
208   virtual AlgorithmType index() const = 0;
209   // Returns the workspace size required by the algorithm in bytes.
210   virtual size_t workspace_size() const = 0;
211 };
212 
213 // Parameters for the CreateBlasLtMatmulPlan method.
214 struct BlasLtMatmulPlanParams {
215   DataType ab_type;
216   DataType c_type;
217   ComputationType computation_type;
218   PointerMode pointer_mode;
219   Epilogue epilogue;
220   Transpose transa;
221   Transpose transb;
222   uint64 m;
223   uint64 n;
224   uint64 k;
225   int64 lda;
226   int64 ldb;
227   int64 ldc;
228   int batch_count = 1;
229   int64 stride_a = 0;
230   int64 stride_b = 0;
231   int64 stride_c = 0;
232 };
233 
234 // BLAS support interface -- this can be derived from a GPU executor when the
// underlying platform has a BLAS library implementation available. See
236 // StreamExecutor::AsBlas().
237 //
238 // Thread-hostile: CUDA associates a CUDA-context with a particular thread in
239 // the system. Any operation that a user attempts to perform by enqueueing BLAS
240 // operations on a thread not-associated with the CUDA-context has unknown
241 // behavior at the current time; see b/13176597
242 class BlasSupport {
243  public:
~BlasSupport()244   virtual ~BlasSupport() {}
245 
246   // Computes the sum of magnitudes of the vector elements.
247   // result <- |Re x(1)| + |Im x(1)| + |Re  x(2)| + |Im  x(2)|+ ... + |Re  x(n)|
248   // + |Im x(n)|.
249   // Note that Im x(i) = 0 for real types float/double.
250   virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
251                           const DeviceMemory<float> &x, int incx,
252                           DeviceMemory<float> *result) = 0;
253   virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
254                           const DeviceMemory<double> &x, int incx,
255                           DeviceMemory<double> *result) = 0;
256   virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
257                           const DeviceMemory<std::complex<float>> &x, int incx,
258                           DeviceMemory<float> *result) = 0;
259   virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
260                           const DeviceMemory<std::complex<double>> &x, int incx,
261                           DeviceMemory<double> *result) = 0;
262 
263   // Performs a BLAS y <- ax+y operation.
264   virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
265                           const DeviceMemory<float> &x, int incx,
266                           DeviceMemory<float> *y, int incy) = 0;
267   virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
268                           const DeviceMemory<double> &x, int incx,
269                           DeviceMemory<double> *y, int incy) = 0;
270   virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count,
271                           std::complex<float> alpha,
272                           const DeviceMemory<std::complex<float>> &x, int incx,
273                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
274   virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count,
275                           std::complex<double> alpha,
276                           const DeviceMemory<std::complex<double>> &x, int incx,
277                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
278 
279   // Copies vector to another vector: y <- x.
280   virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
281                           const DeviceMemory<float> &x, int incx,
282                           DeviceMemory<float> *y, int incy) = 0;
283   virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
284                           const DeviceMemory<double> &x, int incx,
285                           DeviceMemory<double> *y, int incy) = 0;
286   virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
287                           const DeviceMemory<std::complex<float>> &x, int incx,
288                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
289   virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
290                           const DeviceMemory<std::complex<double>> &x, int incx,
291                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
292 
293   // Performs a BLAS dot product result <- x . y.
294   virtual bool DoBlasDot(Stream *stream, uint64 elem_count,
295                          const DeviceMemory<float> &x, int incx,
296                          const DeviceMemory<float> &y, int incy,
297                          DeviceMemory<float> *result) = 0;
298   virtual bool DoBlasDot(Stream *stream, uint64 elem_count,
299                          const DeviceMemory<double> &x, int incx,
300                          const DeviceMemory<double> &y, int incy,
301                          DeviceMemory<double> *result) = 0;
302 
303   // Performs a BLAS dot product result <- conj(x) . y for complex types.
304   virtual bool DoBlasDotc(Stream *stream, uint64 elem_count,
305                           const DeviceMemory<std::complex<float>> &x, int incx,
306                           const DeviceMemory<std::complex<float>> &y, int incy,
307                           DeviceMemory<std::complex<float>> *result) = 0;
308   virtual bool DoBlasDotc(Stream *stream, uint64 elem_count,
309                           const DeviceMemory<std::complex<double>> &x, int incx,
310                           const DeviceMemory<std::complex<double>> &y, int incy,
311                           DeviceMemory<std::complex<double>> *result) = 0;
312 
313   // Performs a BLAS dot product result <- x . y for complex types. Note that
314   // x is unconjugated in this routine.
315   virtual bool DoBlasDotu(Stream *stream, uint64 elem_count,
316                           const DeviceMemory<std::complex<float>> &x, int incx,
317                           const DeviceMemory<std::complex<float>> &y, int incy,
318                           DeviceMemory<std::complex<float>> *result) = 0;
319   virtual bool DoBlasDotu(Stream *stream, uint64 elem_count,
320                           const DeviceMemory<std::complex<double>> &x, int incx,
321                           const DeviceMemory<std::complex<double>> &y, int incy,
322                           DeviceMemory<std::complex<double>> *result) = 0;
323 
324   // Computes the Euclidean norm of a vector: result <- ||x||.
  // See the following link for more information on the Euclidean norm:
326   // http://en.wikipedia.org/wiki/Norm_(mathematics)#Euclidean_norm
327   virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
328                           const DeviceMemory<float> &x, int incx,
329                           DeviceMemory<float> *result) = 0;
330   virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
331                           const DeviceMemory<double> &x, int incx,
332                           DeviceMemory<double> *result) = 0;
333   virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
334                           const DeviceMemory<std::complex<float>> &x, int incx,
335                           DeviceMemory<float> *result) = 0;
336   virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
337                           const DeviceMemory<std::complex<double>> &x, int incx,
338                           DeviceMemory<double> *result) = 0;
339 
340   // Performs rotation of points in the plane:
341   // x(i) = c*x(i) + s*y(i)
342   // y(i) = c*y(i) - s*x(i).
343   virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
344                          DeviceMemory<float> *x, int incx,
345                          DeviceMemory<float> *y, int incy, float c,
346                          float s) = 0;
347   virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
348                          DeviceMemory<double> *x, int incx,
349                          DeviceMemory<double> *y, int incy, double c,
350                          double s) = 0;
351   virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
352                          DeviceMemory<std::complex<float>> *x, int incx,
353                          DeviceMemory<std::complex<float>> *y, int incy,
354                          float c, float s) = 0;
355   virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
356                          DeviceMemory<std::complex<double>> *x, int incx,
357                          DeviceMemory<std::complex<double>> *y, int incy,
358                          double c, double s) = 0;
359 
360   // Computes the parameters for a Givens rotation.
361   // Given the Cartesian coordinates (a, b) of a point, these routines return
362   // the parameters c, s, r, and z associated with the Givens rotation. The
363   // parameters c and s define a unitary matrix such that:
364   //
365   //   |  c s |.| a | = | r |
366   //   | -s c | | b |   | 0 |
367   //
368   // The parameter z is defined such that if |a| > |b|, z is s; otherwise if
369   // c is not 0 z is 1/c; otherwise z is 1.
370   virtual bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
371                           DeviceMemory<float> *b, DeviceMemory<float> *c,
372                           DeviceMemory<float> *s) = 0;
373   virtual bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
374                           DeviceMemory<double> *b, DeviceMemory<double> *c,
375                           DeviceMemory<double> *s) = 0;
376   virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
377                           DeviceMemory<std::complex<float>> *b,
378                           DeviceMemory<float> *c,
379                           DeviceMemory<std::complex<float>> *s) = 0;
380   virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
381                           DeviceMemory<std::complex<double>> *b,
382                           DeviceMemory<double> *c,
383                           DeviceMemory<std::complex<double>> *s) = 0;
384 
385   // Performs modified Givens rotation of points in the plane.
386   // Given two vectors x and y, each vector element of these vectors is replaced
387   // as follows:
388   //
389   //   | x(i) | =  H | x(i) |
390   //   | y(i) |      | y(i) |
391   //
392   // for i=1 to n, where H is a modified Givens transformation matrix whose
393   // values are stored in the param[1] through param[4] array.
  // For details, see the reference BLAS documentation for srotm/drotm.
395   virtual bool DoBlasRotm(Stream *stream, uint64 elem_count,
396                           DeviceMemory<float> *x, int incx,
397                           DeviceMemory<float> *y, int incy,
398                           const DeviceMemory<float> &param) = 0;
399   virtual bool DoBlasRotm(Stream *stream, uint64 elem_count,
400                           DeviceMemory<double> *x, int incx,
401                           DeviceMemory<double> *y, int incy,
402                           const DeviceMemory<double> &param) = 0;
403 
404   // Computes the parameters for a modified Givens rotation.
405   // Given Cartesian coordinates (x1, y1) of an input vector, these routines
406   // compute the components of a modified Givens transformation matrix H that
407   // zeros the y-component of the resulting vector:
408   //
409   //   | x1 | =  H | x1 * sqrt(d1) |
  //   |  0 |      | y1 * sqrt(d2) |
411   //
  // For details, see the reference BLAS documentation for srotmg/drotmg.
413   virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
414                            DeviceMemory<float> *d2, DeviceMemory<float> *x1,
415                            const DeviceMemory<float> &y1,
416                            DeviceMemory<float> *param) = 0;
417   virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
418                            DeviceMemory<double> *d2, DeviceMemory<double> *x1,
419                            const DeviceMemory<double> &y1,
420                            DeviceMemory<double> *param) = 0;
421 
422   // Computes the product of a vector by a scalar: x <- a*x.
423   virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
424                           DeviceMemory<float> *x, int incx) = 0;
425   virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
426                           DeviceMemory<double> *x, int incx) = 0;
427   virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
428                           DeviceMemory<std::complex<float>> *x, int incx) = 0;
429   virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
430                           DeviceMemory<std::complex<double>> *x, int incx) = 0;
431   virtual bool DoBlasScal(Stream *stream, uint64 elem_count,
432                           std::complex<float> alpha,
433                           DeviceMemory<std::complex<float>> *x, int incx) = 0;
434   virtual bool DoBlasScal(Stream *stream, uint64 elem_count,
435                           std::complex<double> alpha,
436                           DeviceMemory<std::complex<double>> *x, int incx) = 0;
437 
438   // Swaps a vector with another vector.
439   virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
440                           DeviceMemory<float> *x, int incx,
441                           DeviceMemory<float> *y, int incy) = 0;
442   virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
443                           DeviceMemory<double> *x, int incx,
444                           DeviceMemory<double> *y, int incy) = 0;
445   virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
446                           DeviceMemory<std::complex<float>> *x, int incx,
447                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
448   virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
449                           DeviceMemory<std::complex<double>> *x, int incx,
450                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
451 
452   // Finds the index of the element with maximum absolute value.
453   virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
454                            const DeviceMemory<float> &x, int incx,
455                            DeviceMemory<int> *result) = 0;
456   virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
457                            const DeviceMemory<double> &x, int incx,
458                            DeviceMemory<int> *result) = 0;
459   virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
460                            const DeviceMemory<std::complex<float>> &x, int incx,
461                            DeviceMemory<int> *result) = 0;
462   virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
463                            const DeviceMemory<std::complex<double>> &x,
464                            int incx, DeviceMemory<int> *result) = 0;
465 
466   // Finds the index of the element with minimum absolute value.
467   virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
468                            const DeviceMemory<float> &x, int incx,
469                            DeviceMemory<int> *result) = 0;
470   virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
471                            const DeviceMemory<double> &x, int incx,
472                            DeviceMemory<int> *result) = 0;
473   virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
474                            const DeviceMemory<std::complex<float>> &x, int incx,
475                            DeviceMemory<int> *result) = 0;
476   virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
477                            const DeviceMemory<std::complex<double>> &x,
478                            int incx, DeviceMemory<int> *result) = 0;
479 
480   // Computes a matrix-vector product using a general band matrix:
481   //
482   //     y <- alpha * a * x + beta * y,
483   // or
484   //     y <- alpha * a' * x + beta * y,
485   // or
486   //     y <- alpha * conj(a') * x + beta * y,
487   //
488   // alpha and beta are scalars; a is an m-by-n general band matrix, with kl
489   // sub-diagonals and ku super-diagonals; x is a vector with
490   // n(trans==kNoTranspose)/m(otherwise) elements;
491   // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements.
492   virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
493                           uint64 n, uint64 kl, uint64 ku, float alpha,
494                           const DeviceMemory<float> &a, int lda,
495                           const DeviceMemory<float> &x, int incx, float beta,
496                           DeviceMemory<float> *y, int incy) = 0;
497   virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
498                           uint64 n, uint64 kl, uint64 ku, double alpha,
499                           const DeviceMemory<double> &a, int lda,
500                           const DeviceMemory<double> &x, int incx, double beta,
501                           DeviceMemory<double> *y, int incy) = 0;
502   virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
503                           uint64 n, uint64 kl, uint64 ku,
504                           std::complex<float> alpha,
505                           const DeviceMemory<std::complex<float>> &a, int lda,
506                           const DeviceMemory<std::complex<float>> &x, int incx,
507                           std::complex<float> beta,
508                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
509   virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
510                           uint64 n, uint64 kl, uint64 ku,
511                           std::complex<double> alpha,
512                           const DeviceMemory<std::complex<double>> &a, int lda,
513                           const DeviceMemory<std::complex<double>> &x, int incx,
514                           std::complex<double> beta,
515                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
516 
517   // Computes a matrix-vector product using a general matrix.
518   //
519   //     y <- alpha * a * x + beta * y,
520   // or
521   //     y <- alpha * a' * x + beta * y,
522   // or
523   //     y <- alpha * conj(a') * x + beta * y,
524   //
525   // alpha and beta are scalars; a is an m-by-n general matrix; x is a vector
526   // with n(trans==kNoTranspose)/m(otherwise) elements;
527   // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements.
528   virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
529                           uint64 n, float alpha, const DeviceMemory<float> &a,
530                           int lda, const DeviceMemory<float> &x, int incx,
531                           float beta, DeviceMemory<float> *y, int incy) = 0;
532   virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
533                           uint64 n, double alpha, const DeviceMemory<double> &a,
534                           int lda, const DeviceMemory<double> &x, int incx,
535                           double beta, DeviceMemory<double> *y, int incy) = 0;
536   virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
537                           uint64 n, std::complex<float> alpha,
538                           const DeviceMemory<std::complex<float>> &a, int lda,
539                           const DeviceMemory<std::complex<float>> &x, int incx,
540                           std::complex<float> beta,
541                           DeviceMemory<std::complex<float>> *y, int incy) = 0;
542   virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
543                           uint64 n, std::complex<double> alpha,
544                           const DeviceMemory<std::complex<double>> &a, int lda,
545                           const DeviceMemory<std::complex<double>> &x, int incx,
546                           std::complex<double> beta,
547                           DeviceMemory<std::complex<double>> *y, int incy) = 0;
548 
549   virtual bool DoBlasGemvWithProfiling(
550       Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
551       const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
552       int incx, float beta, DeviceMemory<float> *y, int incy,
553       ProfileResult *output_profile_result) = 0;
554   virtual bool DoBlasGemvWithProfiling(
555       Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
556       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
557       int incx, double beta, DeviceMemory<double> *y, int incy,
558       ProfileResult *output_profile_result) = 0;
559   virtual bool DoBlasGemvWithProfiling(
560       Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
561       std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
562       int lda, const DeviceMemory<std::complex<float>> &x, int incx,
563       std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
564       ProfileResult *output_profile_result) = 0;
565   virtual bool DoBlasGemvWithProfiling(
566       Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
567       std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
568       int lda, const DeviceMemory<std::complex<double>> &x, int incx,
569       std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
570       int incy, ProfileResult *output_profile_result) = 0;
571 
572   // Performs a rank-1 update of a general matrix.
573   //
574   //     a <- alpha * x * y' + a,
575   //
576   // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
577   // an m-by-n general matrix.
578   virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
579                          const DeviceMemory<float> &x, int incx,
580                          const DeviceMemory<float> &y, int incy,
581                          DeviceMemory<float> *a, int lda) = 0;
582   virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
583                          const DeviceMemory<double> &x, int incx,
584                          const DeviceMemory<double> &y, int incy,
585                          DeviceMemory<double> *a, int lda) = 0;
586 
587   // Performs a rank-1 update (conjugated) of a general matrix.
588   //
589   //     a <- alpha * x * conj(y') + a,
590   //
591   // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
592   // an m-by-n general matrix.
593   virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,
594                           std::complex<float> alpha,
595                           const DeviceMemory<std::complex<float>> &x, int incx,
596                           const DeviceMemory<std::complex<float>> &y, int incy,
597                           DeviceMemory<std::complex<float>> *a, int lda) = 0;
598   virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,
599                           std::complex<double> alpha,
600                           const DeviceMemory<std::complex<double>> &x, int incx,
601                           const DeviceMemory<std::complex<double>> &y, int incy,
602                           DeviceMemory<std::complex<double>> *a, int lda) = 0;
603 
  // Performs a rank-1 update (unconjugated) of a general matrix.
  //
  //     a <- alpha * x * y' + a,
  //
  // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
  // an m-by-n general matrix.
  //
  // Complex-only overloads. Here y' is the plain (non-conjugating) transpose;
  // see DoBlasGerc for the conjugated form. a is the only output and is
  // updated in place.
  virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *a, int lda) = 0;
  virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *a, int lda) = 0;
620 
  // Computes a matrix-vector product using a Hermitian band matrix.
  //
  //     y <- alpha * a * x + beta * y,
  //
  // alpha and beta are scalars; a is an n-by-n Hermitian band matrix, with k
  // super-diagonals; x and y are n-element vectors.
  //
  // Complex-only overloads. y is the only output and is updated in place.
  // uplo presumably selects which triangle of the band storage is referenced,
  // as in reference BLAS ?hbmv -- confirm against the backend implementation.
  virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
639 
  // Computes a matrix-vector product using a Hermitian matrix.
  //
  //     y <- alpha * a * x + beta * y,
  //
  // alpha and beta are scalars; a is an n-by-n Hermitian matrix; x and y are
  // n-element vectors.
  //
  // Complex-only overloads. y is the only output and is updated in place.
  // uplo presumably selects which triangle of a is referenced -- confirm
  // against the backend implementation.
  virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
658 
  // Performs a rank-1 update of a Hermitian matrix.
  //
  //     a <- alpha * x * conj(x') + a,
  //
  // alpha is a scalar; x is an n-element vector; a is an n-by-n Hermitian
  // matrix.
  //
  // Note that alpha is real (float / double) even though x and a are complex.
  // a is the only output and is updated in place.
  virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha,
                         const DeviceMemory<std::complex<float>> &x, int incx,
                         DeviceMemory<std::complex<float>> *a, int lda) = 0;
  virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha,
                         const DeviceMemory<std::complex<double>> &x, int incx,
                         DeviceMemory<std::complex<double>> *a, int lda) = 0;
673 
674   // Performs a rank-2 update of a Hermitian matrix.
675   //
676   //     a <- alpha * x * conj(x') + conj(alpha) * y * conj(x') + a,
677   //
678   // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian
679   // matrix.
680   virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
681                           std::complex<float> alpha,
682                           const DeviceMemory<std::complex<float>> &x, int incx,
683                           const DeviceMemory<std::complex<float>> &y, int incy,
684                           DeviceMemory<std::complex<float>> *a, int lda) = 0;
685   virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
686                           std::complex<double> alpha,
687                           const DeviceMemory<std::complex<double>> &x, int incx,
688                           const DeviceMemory<std::complex<double>> &y, int incy,
689                           DeviceMemory<std::complex<double>> *a, int lda) = 0;
690 
  // Computes a matrix-vector product using a Hermitian packed matrix.
  //
  //     y <- alpha * a * x + beta * y,
  //
  // alpha and beta are scalars; a is an n-by-n Hermitian matrix, supplied in
  // packed form; x and y are n-element vectors.
  //
  // Complex-only overloads. The packed matrix is passed as ap and, unlike
  // DoBlasHemv, takes no leading-dimension argument. y is the only output and
  // is updated in place.
  virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &ap,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &ap,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
709 
  // Performs a rank-1 update of a Hermitian packed matrix.
  //
  //     a <- alpha * x * conj(x') + a,
  //
  // alpha is a scalar; x is an n-element vector; a is an n-by-n Hermitian
  // matrix, supplied in packed form.
  //
  // As with DoBlasHer, alpha is real (float / double). The packed matrix ap is
  // the only output, is updated in place, and takes no leading dimension.
  virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha,
                         const DeviceMemory<std::complex<float>> &x, int incx,
                         DeviceMemory<std::complex<float>> *ap) = 0;
  virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha,
                         const DeviceMemory<std::complex<double>> &x, int incx,
                         DeviceMemory<std::complex<double>> *ap) = 0;
724 
725   // Performs a rank-2 update of a Hermitian packed matrix.
726   //
727   //     a <- alpha * x * conj(x') + conj(alpha) * y * conj(x') + a,
728   //
729   // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian
730   // matrix, supplied in packed form.
731   virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
732                           std::complex<float> alpha,
733                           const DeviceMemory<std::complex<float>> &x, int incx,
734                           const DeviceMemory<std::complex<float>> &y, int incy,
735                           DeviceMemory<std::complex<float>> *ap) = 0;
736   virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
737                           std::complex<double> alpha,
738                           const DeviceMemory<std::complex<double>> &x, int incx,
739                           const DeviceMemory<std::complex<double>> &y, int incy,
740                           DeviceMemory<std::complex<double>> *ap) = 0;
741 
  // Computes a matrix-vector product using a symmetric band matrix.
  //
  //     y <- alpha * a * x + beta * y,
  //
  // alpha and beta are scalars; a is an n-by-n symmetric band matrix, with k
  // super-diagonals; x and y are n-element vectors.
  //
  // Real-only overloads (float / double); the complex counterpart is
  // DoBlasHbmv. y is the only output and is updated in place.
  virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, float alpha, const DeviceMemory<float> &a,
                          int lda, const DeviceMemory<float> &x, int incx,
                          float beta, DeviceMemory<float> *y, int incy) = 0;
  virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, double alpha, const DeviceMemory<double> &a,
                          int lda, const DeviceMemory<double> &x, int incx,
                          double beta, DeviceMemory<double> *y, int incy) = 0;
756 
  // Computes a matrix-vector product using a symmetric packed matrix.
  //
  //     y <- alpha * a * x + beta * y,
  //
  // alpha and beta are scalars; a is an n-by-n symmetric matrix, supplied in
  // packed form; x and y are n-element vectors.
  //
  // Real-only overloads. The packed matrix ap takes no leading-dimension
  // argument; y is the only output and is updated in place.
  virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha, const DeviceMemory<float> &ap,
                          const DeviceMemory<float> &x, int incx, float beta,
                          DeviceMemory<float> *y, int incy) = 0;
  virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          double alpha, const DeviceMemory<double> &ap,
                          const DeviceMemory<double> &x, int incx, double beta,
                          DeviceMemory<double> *y, int incy) = 0;
771 
  // Performs a rank-1 update of a symmetric packed matrix.
  //
  //     a <- alpha * x * x' + a,
  //
  // alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric
  // matrix, supplied in packed form.
  //
  // Real-only overloads. The packed matrix ap is the only output, is updated in
  // place, and takes no leading-dimension argument.
  virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha, const DeviceMemory<float> &x, int incx,
                         DeviceMemory<float> *ap) = 0;
  virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha, const DeviceMemory<double> &x, int incx,
                         DeviceMemory<double> *ap) = 0;
784 
785   // Performs a rank-2 update of a symmetric packed matrix.
786   //
787   //     a <- alpha * x * x' + alpha * y * x' + a,
788   //
789   // alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric
790   // matrix, supplied in packed form.
791   virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
792                           float alpha, const DeviceMemory<float> &x, int incx,
793                           const DeviceMemory<float> &y, int incy,
794                           DeviceMemory<float> *ap) = 0;
795   virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
796                           double alpha, const DeviceMemory<double> &x, int incx,
797                           const DeviceMemory<double> &y, int incy,
798                           DeviceMemory<double> *ap) = 0;
799 
  // Computes a matrix-vector product for a symmetric matrix.
  //
  //     y <- alpha * a * x + beta * y,
  //
  // alpha and beta are scalars; a is an n-by-n symmetric matrix; x and y are
  // n-element vectors.
  //
  // Real-only overloads; the complex counterpart is DoBlasHemv. y is the only
  // output and is updated in place.
  virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha, const DeviceMemory<float> &a, int lda,
                          const DeviceMemory<float> &x, int incx, float beta,
                          DeviceMemory<float> *y, int incy) = 0;
  virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          double alpha, const DeviceMemory<double> &a, int lda,
                          const DeviceMemory<double> &x, int incx, double beta,
                          DeviceMemory<double> *y, int incy) = 0;
814 
  // Performs a rank-1 update of a symmetric matrix.
  //
  //     a <- alpha * x * x' + a,
  //
  // alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric
  // matrix.
  //
  // Real-only overloads; the complex counterpart is DoBlasHer. a is the only
  // output and is updated in place.
  virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha, const DeviceMemory<float> &x, int incx,
                         DeviceMemory<float> *a, int lda) = 0;
  virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha, const DeviceMemory<double> &x, int incx,
                         DeviceMemory<double> *a, int lda) = 0;
827 
828   // Performs a rank-2 update of symmetric matrix.
829   //
830   //     a <- alpha * x * x' + alpha * y * x' + a,
831   //
832   // alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric
833   // matrix.
834   virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
835                           float alpha, const DeviceMemory<float> &x, int incx,
836                           const DeviceMemory<float> &y, int incy,
837                           DeviceMemory<float> *a, int lda) = 0;
838   virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
839                           double alpha, const DeviceMemory<double> &x, int incx,
840                           const DeviceMemory<double> &y, int incy,
841                           DeviceMemory<double> *a, int lda) = 0;
842 
  // Computes a matrix-vector product using a triangular band matrix.
  //
  //     x <- a * x,
  // or
  //     x <- a' * x,
  // or
  //     x <- conj(a') * x,
  //
  // a is an n-by-n unit, or non-unit, upper or lower triangular band matrix,
  // with k+1 diagonals; x is an n-element vector.
  //
  // The product overwrites x in place (x is passed as a mutable pointer); trans
  // selects among the three forms above.
  virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *x, int incx) = 0;
  virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *x, int incx) = 0;
  virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<std::complex<float>> &a,
                          int lda, DeviceMemory<std::complex<float>> *x,
                          int incx) = 0;
  virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<std::complex<double>> &a,
                          int lda, DeviceMemory<std::complex<double>> *x,
                          int incx) = 0;
871 
  // Solves a system of linear equations whose coefficients are in a triangular
  // band matrix as below:
  //
  //     a * x = b,
  // or
  //     a' * x = b,
  // or
  //     conj(a') * x = b,
  //
  // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
  // lower triangular band matrix, with k+1 diagonals.
  //
  // There is no separate b argument: presumably x holds b on entry and is
  // overwritten with the solution, as in reference BLAS ?tbsv -- confirm
  // against the backend implementation.
  virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *x, int incx) = 0;
  virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *x, int incx) = 0;
  virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<std::complex<float>> &a,
                          int lda, DeviceMemory<std::complex<float>> *x,
                          int incx) = 0;
  virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<std::complex<double>> &a,
                          int lda, DeviceMemory<std::complex<double>> *x,
                          int incx) = 0;
901 
  // Computes a matrix-vector product using a triangular packed matrix.
  //
  //     x <- a * x,
  // or
  //     x <- a' * x,
  // or
  //     x <- conj(a') * x,
  //
  // a is an n-by-n unit, or non-unit, upper or lower triangular matrix,
  // supplied in packed form; x is an n-element vector.
  //
  // The packed matrix ap takes no leading-dimension argument; the product
  // overwrites x in place.
  virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                          int incx) = 0;
  virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &ap,
                          DeviceMemory<double> *x, int incx) = 0;
  virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &ap,
                          DeviceMemory<std::complex<float>> *x, int incx) = 0;
  virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &ap,
                          DeviceMemory<std::complex<double>> *x, int incx) = 0;
928 
  // Solves a system of linear equations whose coefficients are in a triangular
  // packed matrix as below:
  //
  //     a * x = b,
  // or
  //     a' * x = b,
  // or
  //     conj(a') * x = b,
  //
  // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
  // lower triangular matrix, supplied in packed form.
  //
  // There is no separate b argument: presumably x holds b on entry and is
  // overwritten with the solution -- confirm against the backend. The packed
  // matrix ap takes no leading-dimension argument.
  virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                          int incx) = 0;
  virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &ap,
                          DeviceMemory<double> *x, int incx) = 0;
  virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &ap,
                          DeviceMemory<std::complex<float>> *x, int incx) = 0;
  virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &ap,
                          DeviceMemory<std::complex<double>> *x, int incx) = 0;
956 
  // Computes a matrix-vector product using a triangular matrix.
  //
  //     x <- a * x,
  // or
  //     x <- a' * x,
  // or
  //     x <- conj(a') * x,
  //
  // a is an n-by-n unit, or non-unit, upper or lower triangular matrix; x is an
  // n-element vector.
  //
  // The product overwrites x in place (x is passed as a mutable pointer).
  virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *x, int incx) = 0;
  virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *x, int incx) = 0;
  virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *x, int incx) = 0;
  virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *x, int incx) = 0;
983 
  // Solves a system of linear equations whose coefficients are in a triangular
  // matrix as below:
  //
  //     a * x = b,
  // or
  //     a' * x = b,
  // or
  //     conj(a') * x = b,
  //
  // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
  // lower triangular matrix.
  //
  // There is no separate b argument: presumably x holds b on entry and is
  // overwritten with the solution, as in reference BLAS ?trsv -- confirm
  // against the backend implementation.
  virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *x, int incx) = 0;
  virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *x, int incx) = 0;
  virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *x, int incx) = 0;
  virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *x, int incx) = 0;
1011 
1012   // Computes a matrix-matrix product with general matrices:
1013   //
1014   //     c <- alpha * op(a) * op(b) + beta * c,
1015   //
1016   // op(X) is one of op(X) = X, or op(X) = X', or op(X) = conj(X'); alpha and
1017   // beta are scalars; a, b, and c are matrices; op(a) is an m-by-k matrix;
1018   // op(b) is a k-by-n matrix; c is an m-by-n matrix.
1019   //
1020   // Note: The half interface uses float precision internally; the version
1021   // that uses half precision internally is not yet supported. There is no
1022   // batched version of the half-precision interface.
1023   virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
1024                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1025                           float alpha, const DeviceMemory<Eigen::half> &a,
1026                           int lda, const DeviceMemory<Eigen::half> &b, int ldb,
1027                           float beta, DeviceMemory<Eigen::half> *c,
1028                           int ldc) = 0;
1029   virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
1030                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1031                           float alpha, const DeviceMemory<float> &a, int lda,
1032                           const DeviceMemory<float> &b, int ldb, float beta,
1033                           DeviceMemory<float> *c, int ldc) = 0;
1034   virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
1035                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1036                           double alpha, const DeviceMemory<double> &a, int lda,
1037                           const DeviceMemory<double> &b, int ldb, double beta,
1038                           DeviceMemory<double> *c, int ldc) = 0;
1039   virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
1040                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1041                           std::complex<float> alpha,
1042                           const DeviceMemory<std::complex<float>> &a, int lda,
1043                           const DeviceMemory<std::complex<float>> &b, int ldb,
1044                           std::complex<float> beta,
1045                           DeviceMemory<std::complex<float>> *c, int ldc) = 0;
1046   virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
1047                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
1048                           std::complex<double> alpha,
1049                           const DeviceMemory<std::complex<double>> &a, int lda,
1050                           const DeviceMemory<std::complex<double>> &b, int ldb,
1051                           std::complex<double> beta,
1052                           DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1053 
  // Like DoBlasGemm (same overload set: Eigen::half with float scalars, float,
  // double, std::complex<float>, std::complex<double>; same argument meanings),
  // but additionally takes output_profile_result.
  //
  // NOTE(review): presumably, when output_profile_result is non-null, the
  // implementation records profiling data for this call into it. The exact
  // failure semantics are not shown here; confirm against the backend.
  virtual bool DoBlasGemmWithProfiling(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
      int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
      DeviceMemory<Eigen::half> *c, int ldc,
      ProfileResult *output_profile_result) = 0;
  virtual bool DoBlasGemmWithProfiling(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
      const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
      int ldc, ProfileResult *output_profile_result) = 0;
  virtual bool DoBlasGemmWithProfiling(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
      const DeviceMemory<double> &b, int ldb, double beta,
      DeviceMemory<double> *c, int ldc,
      ProfileResult *output_profile_result) = 0;
  virtual bool DoBlasGemmWithProfiling(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, std::complex<float> alpha,
      const DeviceMemory<std::complex<float>> &a, int lda,
      const DeviceMemory<std::complex<float>> &b, int ldb,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
      ProfileResult *output_profile_result) = 0;
  virtual bool DoBlasGemmWithProfiling(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, std::complex<double> alpha,
      const DeviceMemory<std::complex<double>> &a, int lda,
      const DeviceMemory<std::complex<double>> &b, int ldb,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
      ProfileResult *output_profile_result) = 0;
1085 
  // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm.
  //
  // The algorithms are written to *out_algorithms. NOTE(review): the meaning of
  // the bool return (presumably "query succeeded") and whether the vector is
  // cleared or appended to are not shown here -- confirm against the backend.
  virtual bool GetBlasGemmAlgorithms(
      std::vector<AlgorithmType> *out_algorithms) = 0;
1089 
1090   // Like DoBlasGemm, but accepts an algorithm and an compute type.
1091   //
1092   // The compute type lets you say (e.g.) that the inputs and outputs are
1093   // Eigen::halfs, but you want the internal computations to be done with
1094   // float32 precision.
1095   //
1096   // Note the subtle difference in the version that accepts Eigen:::half --
1097   // alpha and beta have type const Eigen::half&, not float.
1098   //
1099   // If output_profile_result is not null, a failure here does not put the
1100   // stream in a failure state.  Instead, success/failure is indicated by
1101   // output_profile_result->is_valid().  This lets you use this function for
1102   // choosing the best algorithm among many (some of which may fail) without
1103   // creating a new Stream for each attempt.
1104   virtual bool DoBlasGemmWithAlgorithm(
1105       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1106       uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
1107       const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b,
1108       int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int32> *c,
1109       int ldc, ComputationType computation_type, AlgorithmType algorithm,
1110       ProfileResult *output_profile_result) = 0;
1111   virtual bool DoBlasGemmWithAlgorithm(
1112       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
1113       uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
1114       const DeviceMemory<Eigen::half> &a, int lda,
1115       const DeviceMemory<Eigen::half> &b, int ldb,
1116       const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
1117       int ldc, ComputationType computation_type, AlgorithmType algorithm,
1118       ProfileResult *output_profile_result) = 0;
  // Single-precision real overload of DoBlasGemmWithAlgorithm. It takes the
  // same parameter list as the Eigen::half overload declared immediately
  // before it, with the element type of alpha/beta/a/b/c changed. alpha and
  // beta are HostOrDeviceScalar values (presumably allowing them to reside in
  // either host or device memory; confirm against HostOrDeviceScalar).
  // computation_type and algorithm select how the product is computed, and
  // output_profile_result receives profiling information when non-null
  // (semantics defined by the backend implementation).
  virtual bool DoBlasGemmWithAlgorithm(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
      const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
      int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
      int ldc, ComputationType computation_type, AlgorithmType algorithm,
      ProfileResult *output_profile_result) = 0;
  // Double-precision real overload; same parameters as above.
  virtual bool DoBlasGemmWithAlgorithm(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
      const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
      int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
      int ldc, ComputationType computation_type, AlgorithmType algorithm,
      ProfileResult *output_profile_result) = 0;
  // Single-precision complex overload; same parameters as above.
  virtual bool DoBlasGemmWithAlgorithm(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
      const DeviceMemory<std::complex<float>> &a, int lda,
      const DeviceMemory<std::complex<float>> &b, int ldb,
      const HostOrDeviceScalar<std::complex<float>> &beta,
      DeviceMemory<std::complex<float>> *c, int ldc,
      ComputationType computation_type, AlgorithmType algorithm,
      ProfileResult *output_profile_result) = 0;
  // Double-precision complex overload; same parameters as above.
  virtual bool DoBlasGemmWithAlgorithm(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
      const DeviceMemory<std::complex<double>> &a, int lda,
      const DeviceMemory<std::complex<double>> &b, int ldb,
      const HostOrDeviceScalar<std::complex<double>> &beta,
      DeviceMemory<std::complex<double>> *c, int ldc,
      ComputationType computation_type, AlgorithmType algorithm,
      ProfileResult *output_profile_result) = 0;
1151 
  // Computes a batch of matrix-matrix products with general matrices.
  // This is a batched version of DoBlasGemm.
  // The batched GEMM computes one matrix product for each input/output triple
  // in a, b, and c, each of which contains batch_count DeviceMemory pointers.
  // All products in the batch share the same transa/transb, m/n/k, alpha/beta
  // and leading dimensions lda/ldb/ldc.
  //
  // Note that the Eigen::half overload takes its alpha and beta scalars as
  // float rather than Eigen::half.
  //
  // scratch_allocator is passed through to the backend; how (and whether) it
  // is used is implementation-specific (NOTE(review): confirm whether null is
  // accepted).
  virtual bool DoBlasGemmBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, float alpha,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
      int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
  virtual bool DoBlasGemmBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, float alpha,
      const port::ArraySlice<DeviceMemory<float> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,
      const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator) = 0;
  virtual bool DoBlasGemmBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, double alpha,
      const port::ArraySlice<DeviceMemory<double> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta,
      const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator) = 0;
  virtual bool DoBlasGemmBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, std::complex<float> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
      std::complex<float> beta,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator) = 0;
  virtual bool DoBlasGemmBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, std::complex<double> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
      std::complex<double> beta,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator) = 0;
1193 
  // Batched gemm with strides instead of pointer arrays.
  //
  // Like DoBlasGemmBatched, but each of a, b and c is a single DeviceMemory
  // buffer holding all batch_count matrices, and stride_a/stride_b/stride_c
  // give the offset between consecutive matrices in that buffer
  // (NOTE(review): presumably measured in elements, as in cuBLAS
  // gemmStridedBatched; confirm against the backends).
  //
  // As with DoBlasGemmBatched, the Eigen::half overload takes its alpha and
  // beta scalars as float.
  virtual bool DoBlasGemmStridedBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
      int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
      int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
      int64 stride_c, int batch_count) = 0;
  virtual bool DoBlasGemmStridedBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
      int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
      float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
      int batch_count) = 0;
  virtual bool DoBlasGemmStridedBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
      int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
      double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
      int batch_count) = 0;
  virtual bool DoBlasGemmStridedBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, std::complex<float> alpha,
      const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
      const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
      int64 stride_c, int batch_count) = 0;
  virtual bool DoBlasGemmStridedBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, std::complex<double> alpha,
      const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
      const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
      int64 stride_c, int batch_count) = 0;
1227 
  // Computes a matrix-matrix product where one input matrix is Hermitian:
  //
  //     c <- alpha * a * b + beta * c,
  // or
  //     c <- alpha * b * a + beta * c,
  //
  // alpha and beta are scalars; a is a Hermitian matrix; b and c are m-by-n
  // matrices. The form is presumably chosen by `side`, and `uplo` presumably
  // says which triangle of a is referenced (standard BLAS HEMM convention;
  // confirm against the backend).
  //
  // Only complex overloads exist, since Hermitian matrices are a complex
  // concept.
  virtual bool DoBlasHemm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &b, int ldb,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *c, int ldc) = 0;
  virtual bool DoBlasHemm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &b, int ldb,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1250 
  // Performs a Hermitian rank-k update.
  //
  //     c <- alpha * a * conj(a') + beta * c,
  // or
  //     c <- alpha * conj(a') * a + beta * c,
  //
  // alpha and beta are scalars; c is an n-by-n Hermitian matrix; a is an n-by-k
  // matrix in the first case and a k-by-n matrix in the second case. The form
  // is presumably selected by `trans` (standard BLAS HERK convention).
  //
  // Note that alpha and beta are real (float/double), not complex; this keeps
  // the result Hermitian.
  virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          float alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          float beta, DeviceMemory<std::complex<float>> *c,
                          int ldc) = 0;
  virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          double alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          double beta, DeviceMemory<std::complex<double>> *c,
                          int ldc) = 0;
1271 
  // Performs a Hermitian rank-2k update.
  //
  //     c <- alpha * a * conj(b') + conj(alpha) * b * conj(a') + beta * c,
  // or
  //     c <- alpha * conj(a') * b + conj(alpha) * conj(b') * a + beta * c,
  //
  // alpha and beta are scalars; c is an n-by-n Hermitian matrix; a and b are
  // n-by-k matrices in the first case and k-by-n matrices in the second case.
  // The form is presumably selected by `trans` (standard BLAS HER2K
  // convention). In both forms alpha scales the term in which a appears
  // first, matching reference BLAS HER2K.
  //
  // Note that alpha is complex while beta is real (float/double); a real beta
  // keeps the result Hermitian.
  virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           std::complex<float> alpha,
                           const DeviceMemory<std::complex<float>> &a, int lda,
                           const DeviceMemory<std::complex<float>> &b, int ldb,
                           float beta, DeviceMemory<std::complex<float>> *c,
                           int ldc) = 0;
  virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           std::complex<double> alpha,
                           const DeviceMemory<std::complex<double>> &a, int lda,
                           const DeviceMemory<std::complex<double>> &b, int ldb,
                           double beta, DeviceMemory<std::complex<double>> *c,
                           int ldc) = 0;
1294 
  // Computes a matrix-matrix product where one input matrix is symmetric.
  //
  //     c <- alpha * a * b + beta * c,
  // or
  //     c <- alpha * b * a + beta * c,
  //
  // alpha and beta are scalars; a is a symmetric matrix; b and c are m-by-n
  // matrices. The form is presumably chosen by `side`, and `uplo` presumably
  // says which triangle of a is referenced (standard BLAS SYMM convention).
  //
  // Overloads exist for float, double, complex<float> and complex<double>.
  virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          float alpha, const DeviceMemory<float> &a, int lda,
                          const DeviceMemory<float> &b, int ldb, float beta,
                          DeviceMemory<float> *c, int ldc) = 0;
  virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          double alpha, const DeviceMemory<double> &a, int lda,
                          const DeviceMemory<double> &b, int ldb, double beta,
                          DeviceMemory<double> *c, int ldc) = 0;
  virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &b, int ldb,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *c, int ldc) = 0;
  virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &b, int ldb,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1327 
  // Performs a symmetric rank-k update.
  //
  //     c <- alpha * a * a' + beta * c,
  // or
  //     c <- alpha * a' * a + beta * c,
  //
  // alpha and beta are scalars; c is an n-by-n symmetric matrix; a is an n-by-k
  // matrix in the first case and a k-by-n matrix in the second case. The form
  // is presumably selected by `trans` (standard BLAS SYRK convention).
  //
  // Unlike DoBlasHerk, the complex overloads take complex alpha and beta and
  // use a plain transpose (a'), not a conjugate transpose.
  virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          float alpha, const DeviceMemory<float> &a, int lda,
                          float beta, DeviceMemory<float> *c, int ldc) = 0;
  virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          double alpha, const DeviceMemory<double> &a, int lda,
                          double beta, DeviceMemory<double> *c, int ldc) = 0;
  virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *c, int ldc) = 0;
  virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1356 
  // Performs a symmetric rank-2k update.
  //
  //     c <- alpha * a * b' + alpha * b * a' + beta * c,
  // or
  //     c <- alpha * b' * a + alpha * a' * b + beta * c,
  //
  // alpha and beta are scalars; c is an n-by-n symmetric matrix; a and b are
  // n-by-k matrices in the first case and k-by-n matrices in the second case.
  // The form is presumably selected by `trans` (standard BLAS SYR2K
  // convention).
  //
  // Overloads exist for float, double, complex<float> and complex<double>.
  virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           float alpha, const DeviceMemory<float> &a, int lda,
                           const DeviceMemory<float> &b, int ldb, float beta,
                           DeviceMemory<float> *c, int ldc) = 0;
  virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           double alpha, const DeviceMemory<double> &a, int lda,
                           const DeviceMemory<double> &b, int ldb, double beta,
                           DeviceMemory<double> *c, int ldc) = 0;
  virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           std::complex<float> alpha,
                           const DeviceMemory<std::complex<float>> &a, int lda,
                           const DeviceMemory<std::complex<float>> &b, int ldb,
                           std::complex<float> beta,
                           DeviceMemory<std::complex<float>> *c, int ldc) = 0;
  virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           std::complex<double> alpha,
                           const DeviceMemory<std::complex<double>> &a, int lda,
                           const DeviceMemory<std::complex<double>> &b, int ldb,
                           std::complex<double> beta,
                           DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1389 
  // Computes a matrix-matrix product where one input matrix is triangular.
  //
  //     b <- alpha * op(a) * b,
  // or
  //     b <- alpha * b * op(a)
  //
  // alpha is a scalar; b is an m-by-n matrix; a is a unit, or non-unit, upper
  // or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a', or
  // op(a) = conj(a').
  //
  // The result overwrites b in place (b is the only output parameter). The
  // side, triangle, op() and unit-diagonal choices are presumably selected by
  // side, uplo, transa and diag respectively (standard BLAS TRMM convention).
  virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n, float alpha,
                          const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *b, int ldb) = 0;
  virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n, double alpha,
                          const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *b, int ldb) = 0;
  virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *b, int ldb) = 0;
  virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *b, int ldb) = 0;
1421 
  // Solves a triangular matrix equation.
  //
  //     op(a) * x = alpha * b,
  // or
  //     x * op(a) = alpha * b
  //
  // alpha is a scalar; x and b are m-by-n matrices; a is a unit, or non-unit,
  // upper or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a',
  // or op(a) = conj(a').
  //
  // There is no separate x parameter: b is the only output, so the solution
  // is presumably written over b in place (standard BLAS TRSM convention;
  // confirm against the backend).
  virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n, float alpha,
                          const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *b, int ldb) = 0;
  virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n, double alpha,
                          const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *b, int ldb) = 0;
  virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *b, int ldb) = 0;
  virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *b, int ldb) = 0;
1453 
  // Creates a backend-specific plan object for a blaslt matmul operation, which
  // can then be passed to DoBlasLtMatmul(). When possible, plans should be
  // created once and reused for multiple calls to DoBlasLtMatmul().
  //
  // On success, returns an owning pointer to the plan. On failure, returns a
  // non-OK status (the set of failure conditions is backend-specific).
  virtual port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
  CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &params) = 0;
1459 
  // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
  // returned in the order of increasing estimated compute time according to an
  // internal heuristic. The first returned algorithm can be used as the default
  // algorithm if no autotuning is to be performed.
  //
  // plan is a plan obtained from CreateBlasLtMatmulPlan(). max_workspace_size
  // presumably bounds the scratch memory (in bytes; confirm) that a returned
  // algorithm may require, and max_algorithm_count presumably caps the length
  // of the returned list.
  virtual port::StatusOr<
      std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
  GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
                            size_t max_workspace_size,
                            int max_algorithm_count) = 0;
1469 
  // Executes a blaslt matmul operation on the stream. If output_profile_result
  // is not nullptr, the operation is profiled, error messages are
  // suppressed, and output_profile_result->algorithm() is set to
  // algorithm->index(). If epilogue was set to kBias or kBiasThenReLU when
  // creating the plan, the bias argument here must refer to a valid device
  // vector of length equal to the number of rows in matrix c. If epilogue was
  // set to any other value then the bias argument here must be null. The bias
  // vector is broadcast across the batch dimension.
  // Note that the data types of a and b (c and bias) must match the ab_type
  // (c_type) with which the plan was created, and the data types of alpha and
  // beta must match the data type of c.
  //
  // This is the type-erased entry point: a, b, c and bias are passed as
  // untyped DeviceMemoryBase, so this overload cannot check element types
  // itself. Prefer the typed DoBlasLtMatmul<ABType, CType> template, which
  // checks ABType/CType against the plan before forwarding here.
  virtual bool DoBlasLtMatmul(
      Stream *stream, const blas::IBlasLtMatmulPlan *plan,
      const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
      DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
      DeviceMemoryBase c, ScratchAllocator *scratch_allocator,
      const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias,
      blas::ProfileResult *output_profile_result) = 0;
1488 
1489   template <typename ABType, typename CType>
1490   bool DoBlasLtMatmul(Stream *stream, const blas::IBlasLtMatmulPlan *plan,
1491                       const HostOrDeviceScalar<CType> &alpha,
1492                       const DeviceMemory<ABType> &a,
1493                       const DeviceMemory<ABType> &b,
1494                       const HostOrDeviceScalar<CType> &beta,
1495                       DeviceMemory<CType> *c,
1496                       ScratchAllocator *scratch_allocator,
1497                       const blas::IBlasLtMatmulAlgorithm *algorithm,
1498                       const DeviceMemory<CType> &bias = {},
1499                       blas::ProfileResult *output_profile_result = nullptr) {
1500     constexpr blas::DataType ab_type = blas::ToDataType<ABType>::value;
1501     if (ab_type != plan->ab_type()) {
1502       VLOG(2) << "DoBlasLtMatmul returning false because a and b type does "
1503                  "not match plan: expected "
1504               << plan->ab_type() << ", got " << ab_type;
1505       return false;
1506     }
1507     constexpr blas::DataType c_type = blas::ToDataType<CType>::value;
1508     if (c_type != plan->c_type()) {
1509       VLOG(2) << "DoBlasLtMatmul returning false because c type does "
1510                  "not match plan: expected "
1511               << plan->c_type() << ", got " << c_type;
1512       return false;
1513     }
1514     return DoBlasLtMatmul(stream, plan, alpha, a, b, beta, *c,
1515                           scratch_allocator, algorithm, bias,
1516                           output_profile_result);
1517   }
1518 
  // Reports version information for the BLAS backend by writing it into
  // *version (NOTE(review): the exact format is backend-specific; confirm the
  // contents with the implementations). Returns a non-OK status on failure.
  virtual port::Status GetVersion(std::string *version) = 0;
1520 
 protected:
  // Protected so that BlasSupport can only be constructed as the base of a
  // concrete backend implementation. The constructor is declared explicitly
  // because the copy-disabling macro below suppresses the implicit one.
  BlasSupport() {}

 private:
  // Instances are non-copyable and non-assignable.
  SE_DISALLOW_COPY_AND_ASSIGN(BlasSupport);
1526 };
1527 
1528 // Macro used to quickly declare overrides for abstract virtuals in the
1529 // BlasSupport base class.
1530 #define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES                  \
1531   bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1532                   const DeviceMemory<float> &x, int incx,                      \
1533                   DeviceMemory<float> *result) override;                       \
1534   bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1535                   const DeviceMemory<double> &x, int incx,                     \
1536                   DeviceMemory<double> *result) override;                      \
1537   bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1538                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1539                   DeviceMemory<float> *result) override;                       \
1540   bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1541                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1542                   DeviceMemory<double> *result) override;                      \
1543   bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,              \
1544                   const DeviceMemory<float> &x, int incx,                      \
1545                   DeviceMemory<float> *y, int incy) override;                  \
1546   bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,             \
1547                   const DeviceMemory<double> &x, int incx,                     \
1548                   DeviceMemory<double> *y, int incy) override;                 \
1549   bool DoBlasAxpy(Stream *stream, uint64 elem_count,                           \
1550                   std::complex<float> alpha,                                   \
1551                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1552                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1553   bool DoBlasAxpy(Stream *stream, uint64 elem_count,                           \
1554                   std::complex<double> alpha,                                  \
1555                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1556                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1557   bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1558                   const DeviceMemory<float> &x, int incx,                      \
1559                   DeviceMemory<float> *y, int incy) override;                  \
1560   bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1561                   const DeviceMemory<double> &x, int incx,                     \
1562                   DeviceMemory<double> *y, int incy) override;                 \
1563   bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1564                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1565                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1566   bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1567                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1568                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1569   bool DoBlasDot(Stream *stream, uint64 elem_count,                            \
1570                  const DeviceMemory<float> &x, int incx,                       \
1571                  const DeviceMemory<float> &y, int incy,                       \
1572                  DeviceMemory<float> *result) override;                        \
1573   bool DoBlasDot(Stream *stream, uint64 elem_count,                            \
1574                  const DeviceMemory<double> &x, int incx,                      \
1575                  const DeviceMemory<double> &y, int incy,                      \
1576                  DeviceMemory<double> *result) override;                       \
1577   bool DoBlasDotc(Stream *stream, uint64 elem_count,                           \
1578                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1579                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1580                   DeviceMemory<std::complex<float>> *result) override;         \
1581   bool DoBlasDotc(Stream *stream, uint64 elem_count,                           \
1582                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1583                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1584                   DeviceMemory<std::complex<double>> *result) override;        \
1585   bool DoBlasDotu(Stream *stream, uint64 elem_count,                           \
1586                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1587                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1588                   DeviceMemory<std::complex<float>> *result) override;         \
1589   bool DoBlasDotu(Stream *stream, uint64 elem_count,                           \
1590                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1591                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1592                   DeviceMemory<std::complex<double>> *result) override;        \
1593   bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1594                   const DeviceMemory<float> &x, int incx,                      \
1595                   DeviceMemory<float> *result) override;                       \
1596   bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1597                   const DeviceMemory<double> &x, int incx,                     \
1598                   DeviceMemory<double> *result) override;                      \
1599   bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1600                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1601                   DeviceMemory<float> *result) override;                       \
1602   bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1603                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1604                   DeviceMemory<double> *result) override;                      \
  /* BLAS ROT: apply a plane rotation (c, s) to vectors x and y in place. */   \
  bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<float> *x,    \
                 int incx, DeviceMemory<float> *y, int incy, float c, float s) \
      override;                                                                \
  bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<double> *x,   \
                 int incx, DeviceMemory<double> *y, int incy, double c,        \
                 double s) override;                                           \
  /* Complex overloads take real c and s (cf. CSROT/ZDROT). */                 \
  bool DoBlasRot(Stream *stream, uint64 elem_count,                            \
                 DeviceMemory<std::complex<float>> *x, int incx,               \
                 DeviceMemory<std::complex<float>> *y, int incy, float c,      \
                 float s) override;                                            \
  bool DoBlasRot(Stream *stream, uint64 elem_count,                            \
                 DeviceMemory<std::complex<double>> *x, int incx,              \
                 DeviceMemory<std::complex<double>> *y, int incy, double c,    \
                 double s) override;                                           \
  /* BLAS ROTG: construct Givens rotation parameters c, s from a and b. */     \
  bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a,                      \
                  DeviceMemory<float> *b, DeviceMemory<float> *c,              \
                  DeviceMemory<float> *s) override;                            \
  bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a,                     \
                  DeviceMemory<double> *b, DeviceMemory<double> *c,            \
                  DeviceMemory<double> *s) override;                           \
  /* Complex overloads write c as a real value (cf. CROTG/ZROTG). */           \
  bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,        \
                  DeviceMemory<std::complex<float>> *b,                        \
                  DeviceMemory<float> *c,                                      \
                  DeviceMemory<std::complex<float>> *s) override;              \
  bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,       \
                  DeviceMemory<std::complex<double>> *b,                       \
                  DeviceMemory<double> *c,                                     \
                  DeviceMemory<std::complex<double>> *s) override;             \
  /* BLAS ROTM: apply a modified Givens rotation described by param. */        \
  bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<float> *x,   \
                  int incx, DeviceMemory<float> *y, int incy,                  \
                  const DeviceMemory<float> &param) override;                  \
  bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<double> *x,  \
                  int incx, DeviceMemory<double> *y, int incy,                 \
                  const DeviceMemory<double> &param) override;                 \
  /* BLAS ROTMG: construct modified Givens parameters into *param. */          \
  bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,                    \
                   DeviceMemory<float> *d2, DeviceMemory<float> *x1,           \
                   const DeviceMemory<float> &y1, DeviceMemory<float> *param)  \
      override;                                                                \
  bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,                   \
                   DeviceMemory<double> *d2, DeviceMemory<double> *x1,         \
                   const DeviceMemory<double> &y1,                             \
                   DeviceMemory<double> *param) override;                      \
  /* BLAS SCAL: x *= alpha. Complex x may take a real or complex alpha. */     \
  bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,              \
                  DeviceMemory<float> *x, int incx) override;                  \
  bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,             \
                  DeviceMemory<double> *x, int incx) override;                 \
  bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,              \
                  DeviceMemory<std::complex<float>> *x, int incx) override;    \
  bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,             \
                  DeviceMemory<std::complex<double>> *x, int incx) override;   \
  bool DoBlasScal(Stream *stream, uint64 elem_count,                           \
                  std::complex<float> alpha,                                   \
                  DeviceMemory<std::complex<float>> *x, int incx) override;    \
  bool DoBlasScal(Stream *stream, uint64 elem_count,                           \
                  std::complex<double> alpha,                                  \
                  DeviceMemory<std::complex<double>> *x, int incx) override;   \
  /* BLAS SWAP: exchange the elements of vectors x and y. */                   \
  bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<float> *x,   \
                  int incx, DeviceMemory<float> *y, int incy) override;        \
  bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<double> *x,  \
                  int incx, DeviceMemory<double> *y, int incy) override;       \
  bool DoBlasSwap(Stream *stream, uint64 elem_count,                           \
                  DeviceMemory<std::complex<float>> *x, int incx,              \
                  DeviceMemory<std::complex<float>> *y, int incy) override;    \
  bool DoBlasSwap(Stream *stream, uint64 elem_count,                           \
                  DeviceMemory<std::complex<double>> *x, int incx,             \
                  DeviceMemory<std::complex<double>> *y, int incy) override;   \
  /* NOTE(review): index base (0 or 1) not shown here; confirm. */             \
  /* BLAS IAMAX: index of the max-magnitude element of x, into *result. */     \
  bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
                   const DeviceMemory<float> &x, int incx,                     \
                   DeviceMemory<int> *result) override;                        \
  bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
                   const DeviceMemory<double> &x, int incx,                    \
                   DeviceMemory<int> *result) override;                        \
  bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
                   const DeviceMemory<std::complex<float>> &x, int incx,       \
                   DeviceMemory<int> *result) override;                        \
  bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
                   const DeviceMemory<std::complex<double>> &x, int incx,      \
                   DeviceMemory<int> *result) override;                        \
  /* BLAS IAMIN: index of the min-magnitude element of x, into *result. */     \
  bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
                   const DeviceMemory<float> &x, int incx,                     \
                   DeviceMemory<int> *result) override;                        \
  bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
                   const DeviceMemory<double> &x, int incx,                    \
                   DeviceMemory<int> *result) override;                        \
  bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
                   const DeviceMemory<std::complex<float>> &x, int incx,       \
                   DeviceMemory<int> *result) override;                        \
  bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
                   const DeviceMemory<std::complex<double>> &x, int incx,      \
                   DeviceMemory<int> *result) override;                        \
  /* BLAS GBMV: y = alpha*op(A)*x + beta*y, A is an m x n band matrix. */      \
  /* kl/ku are the sub-/super-diagonal counts of A. */                         \
  bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
                  uint64 kl, uint64 ku, float alpha,                           \
                  const DeviceMemory<float> &a, int lda,                       \
                  const DeviceMemory<float> &x, int incx, float beta,          \
                  DeviceMemory<float> *y, int incy) override;                  \
  bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
                  uint64 kl, uint64 ku, double alpha,                          \
                  const DeviceMemory<double> &a, int lda,                      \
                  const DeviceMemory<double> &x, int incx, double beta,        \
                  DeviceMemory<double> *y, int incy) override;                 \
  bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
                  uint64 kl, uint64 ku, std::complex<float> alpha,             \
                  const DeviceMemory<std::complex<float>> &a, int lda,         \
                  const DeviceMemory<std::complex<float>> &x, int incx,        \
                  std::complex<float> beta,                                    \
                  DeviceMemory<std::complex<float>> *y, int incy) override;    \
  bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
                  uint64 kl, uint64 ku, std::complex<double> alpha,            \
                  const DeviceMemory<std::complex<double>> &a, int lda,        \
                  const DeviceMemory<std::complex<double>> &x, int incx,       \
                  std::complex<double> beta,                                   \
                  DeviceMemory<std::complex<double>> *y, int incy) override;   \
  /* BLAS GEMV: y = alpha*op(A)*x + beta*y for a general m x n matrix A. */    \
  bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
                  float alpha, const DeviceMemory<float> &a, int lda,          \
                  const DeviceMemory<float> &x, int incx, float beta,          \
                  DeviceMemory<float> *y, int incy) override;                  \
  bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
                  double alpha, const DeviceMemory<double> &a, int lda,        \
                  const DeviceMemory<double> &x, int incx, double beta,        \
                  DeviceMemory<double> *y, int incy) override;                 \
  bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
                  std::complex<float> alpha,                                   \
                  const DeviceMemory<std::complex<float>> &a, int lda,         \
                  const DeviceMemory<std::complex<float>> &x, int incx,        \
                  std::complex<float> beta,                                    \
                  DeviceMemory<std::complex<float>> *y, int incy) override;    \
  bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
                  std::complex<double> alpha,                                  \
                  const DeviceMemory<std::complex<double>> &a, int lda,        \
                  const DeviceMemory<std::complex<double>> &x, int incx,       \
                  std::complex<double> beta,                                   \
                  DeviceMemory<std::complex<double>> *y, int incy) override;   \
  /* GEMV overloads with a blas::ProfileResult* out-parameter. */              \
  bool DoBlasGemvWithProfiling(                                                \
      Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,  \
      const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,     \
      int incx, float beta, DeviceMemory<float> *y, int incy,                  \
      blas::ProfileResult *output_profile_result) override;                    \
  bool DoBlasGemvWithProfiling(                                                \
      Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, \
      const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,   \
      int incx, double beta, DeviceMemory<double> *y, int incy,                \
      blas::ProfileResult *output_profile_result) override;                    \
  bool DoBlasGemvWithProfiling(                                                \
      Stream *stream, blas::Transpose trans, uint64 m, uint64 n,               \
      std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,   \
      int lda, const DeviceMemory<std::complex<float>> &x, int incx,           \
      std::complex<float> beta, DeviceMemory<std::complex<float>> *y,          \
      int incy, blas::ProfileResult *output_profile_result) override;          \
  bool DoBlasGemvWithProfiling(                                                \
      Stream *stream, blas::Transpose trans, uint64 m, uint64 n,               \
      std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, \
      int lda, const DeviceMemory<std::complex<double>> &x, int incx,          \
      std::complex<double> beta, DeviceMemory<std::complex<double>> *y,        \
      int incy, blas::ProfileResult *output_profile_result) override;          \
  /* BLAS GER: rank-1 update A += alpha*x*y^T (real types). */                 \
  bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,              \
                 const DeviceMemory<float> &x, int incx,                       \
                 const DeviceMemory<float> &y, int incy,                       \
                 DeviceMemory<float> *a, int lda) override;                    \
  bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,             \
                 const DeviceMemory<double> &x, int incx,                      \
                 const DeviceMemory<double> &y, int incy,                      \
                 DeviceMemory<double> *a, int lda) override;                   \
  /* BLAS GERC: rank-1 update A += alpha*x*y^H (conjugated y). */              \
  bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,                          \
                  std::complex<float> alpha,                                   \
                  const DeviceMemory<std::complex<float>> &x, int incx,        \
                  const DeviceMemory<std::complex<float>> &y, int incy,        \
                  DeviceMemory<std::complex<float>> *a, int lda) override;     \
  bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,                          \
                  std::complex<double> alpha,                                  \
                  const DeviceMemory<std::complex<double>> &x, int incx,       \
                  const DeviceMemory<std::complex<double>> &y, int incy,       \
                  DeviceMemory<std::complex<double>> *a, int lda) override;    \
  /* BLAS GERU: rank-1 update A += alpha*x*y^T (unconjugated y). */            \
  bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,                          \
                  std::complex<float> alpha,                                   \
                  const DeviceMemory<std::complex<float>> &x, int incx,        \
                  const DeviceMemory<std::complex<float>> &y, int incy,        \
                  DeviceMemory<std::complex<float>> *a, int lda) override;     \
  bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,                          \
                  std::complex<double> alpha,                                  \
                  const DeviceMemory<std::complex<double>> &x, int incx,       \
                  const DeviceMemory<std::complex<double>> &y, int incy,       \
                  DeviceMemory<std::complex<double>> *a, int lda) override;    \
  /* BLAS HBMV: y = alpha*A*x + beta*y, A Hermitian band (k super-diags). */   \
  bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
                  std::complex<float> alpha,                                   \
                  const DeviceMemory<std::complex<float>> &a, int lda,         \
                  const DeviceMemory<std::complex<float>> &x, int incx,        \
                  std::complex<float> beta,                                    \
                  DeviceMemory<std::complex<float>> *y, int incy) override;    \
  bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
                  std::complex<double> alpha,                                  \
                  const DeviceMemory<std::complex<double>> &a, int lda,        \
                  const DeviceMemory<std::complex<double>> &x, int incx,       \
                  std::complex<double> beta,                                   \
                  DeviceMemory<std::complex<double>> *y, int incy) override;   \
  /* BLAS HEMV: y = alpha*A*x + beta*y with A Hermitian (uplo half used). */   \
  bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
                  std::complex<float> alpha,                                   \
                  const DeviceMemory<std::complex<float>> &a, int lda,         \
                  const DeviceMemory<std::complex<float>> &x, int incx,        \
                  std::complex<float> beta,                                    \
                  DeviceMemory<std::complex<float>> *y, int incy) override;    \
  bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
                  std::complex<double> alpha,                                  \
                  const DeviceMemory<std::complex<double>> &a, int lda,        \
                  const DeviceMemory<std::complex<double>> &x, int incx,       \
                  std::complex<double> beta,                                   \
                  DeviceMemory<std::complex<double>> *y, int incy) override;   \
  /* BLAS HER: A += alpha*x*x^H, Hermitian A; alpha is real. */                \
  bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
                 const DeviceMemory<std::complex<float>> &x, int incx,         \
                 DeviceMemory<std::complex<float>> *a, int lda) override;      \
  bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,              \
                 double alpha, const DeviceMemory<std::complex<double>> &x,    \
                 int incx, DeviceMemory<std::complex<double>> *a, int lda)     \
      override;                                                                \
  /* BLAS HER2: A += alpha*x*y^H + conj(alpha)*y*x^H, Hermitian A. */          \
  bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
                  std::complex<float> alpha,                                   \
                  const DeviceMemory<std::complex<float>> &x, int incx,        \
                  const DeviceMemory<std::complex<float>> &y, int incy,        \
                  DeviceMemory<std::complex<float>> *a, int lda) override;     \
  bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
                  std::complex<double> alpha,                                  \
                  const DeviceMemory<std::complex<double>> &x, int incx,       \
                  const DeviceMemory<std::complex<double>> &y, int incy,       \
                  DeviceMemory<std::complex<double>> *a, int lda) override;    \
  /* BLAS HPMV: y = alpha*A*x + beta*y, A Hermitian, packed in ap. */          \
  bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
                  std::complex<float> alpha,                                   \
                  const DeviceMemory<std::complex<float>> &ap,                 \
                  const DeviceMemory<std::complex<float>> &x, int incx,        \
                  std::complex<float> beta,                                    \
                  DeviceMemory<std::complex<float>> *y, int incy) override;    \
  bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
                  std::complex<double> alpha,                                  \
                  const DeviceMemory<std::complex<double>> &ap,                \
                  const DeviceMemory<std::complex<double>> &x, int incx,       \
                  std::complex<double> beta,                                   \
                  DeviceMemory<std::complex<double>> *y, int incy) override;   \
  /* BLAS HPR: A += alpha*x*x^H, A Hermitian packed in ap; real alpha. */      \
  bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
                 const DeviceMemory<std::complex<float>> &x, int incx,         \
                 DeviceMemory<std::complex<float>> *ap) override;              \
  bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,              \
                 double alpha, const DeviceMemory<std::complex<double>> &x,    \
                 int incx, DeviceMemory<std::complex<double>> *ap) override;   \
  /* BLAS HPR2: HER2-style rank-2 update of Hermitian A packed in ap. */       \
  bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
                  std::complex<float> alpha,                                   \
                  const DeviceMemory<std::complex<float>> &x, int incx,        \
                  const DeviceMemory<std::complex<float>> &y, int incy,        \
                  DeviceMemory<std::complex<float>> *ap) override;             \
  bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
                  std::complex<double> alpha,                                  \
                  const DeviceMemory<std::complex<double>> &x, int incx,       \
                  const DeviceMemory<std::complex<double>> &y, int incy,       \
                  DeviceMemory<std::complex<double>> *ap) override;            \
  /* BLAS SBMV: y = alpha*A*x + beta*y, A symmetric band (k super-diags). */   \
  bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
                  float alpha, const DeviceMemory<float> &a, int lda,          \
                  const DeviceMemory<float> &x, int incx, float beta,          \
                  DeviceMemory<float> *y, int incy) override;                  \
  bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
                  double alpha, const DeviceMemory<double> &a, int lda,        \
                  const DeviceMemory<double> &x, int incx, double beta,        \
                  DeviceMemory<double> *y, int incy) override;                 \
  /* BLAS SPMV: y = alpha*A*x + beta*y, A symmetric, packed in ap. */          \
  bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
                  float alpha, const DeviceMemory<float> &ap,                  \
                  const DeviceMemory<float> &x, int incx, float beta,          \
                  DeviceMemory<float> *y, int incy) override;                  \
  bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
                  double alpha, const DeviceMemory<double> &ap,                \
                  const DeviceMemory<double> &x, int incx, double beta,        \
                  DeviceMemory<double> *y, int incy) override;                 \
  /* BLAS SPR: A += alpha*x*x^T, A symmetric packed in ap. */                  \
  bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
                 const DeviceMemory<float> &x, int incx,                       \
                 DeviceMemory<float> *ap) override;                            \
  bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,              \
                 double alpha, const DeviceMemory<double> &x, int incx,        \
                 DeviceMemory<double> *ap) override;                           \
  /* BLAS SPR2: A += alpha*(x*y^T + y*x^T), A symmetric packed in ap. */       \
  bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
                  float alpha, const DeviceMemory<float> &x, int incx,         \
                  const DeviceMemory<float> &y, int incy,                      \
                  DeviceMemory<float> *ap) override;                           \
  bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
                  double alpha, const DeviceMemory<double> &x, int incx,       \
                  const DeviceMemory<double> &y, int incy,                     \
                  DeviceMemory<double> *ap) override;                          \
  /* BLAS SYMV: y = alpha*A*x + beta*y with A symmetric (uplo half used). */   \
  bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
                  float alpha, const DeviceMemory<float> &a, int lda,          \
                  const DeviceMemory<float> &x, int incx, float beta,          \
                  DeviceMemory<float> *y, int incy) override;                  \
  bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
                  double alpha, const DeviceMemory<double> &a, int lda,        \
                  const DeviceMemory<double> &x, int incx, double beta,        \
                  DeviceMemory<double> *y, int incy) override;                 \
  /* BLAS SYR: rank-1 update A += alpha*x*x^T of symmetric A. */               \
  bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
                 const DeviceMemory<float> &x, int incx,                       \
                 DeviceMemory<float> *a, int lda) override;                    \
  bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,              \
                 double alpha, const DeviceMemory<double> &x, int incx,        \
                 DeviceMemory<double> *a, int lda) override;                   \
  /* BLAS SYR2: A += alpha*(x*y^T + y*x^T) of symmetric A. */                  \
  bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
                  float alpha, const DeviceMemory<float> &x, int incx,         \
                  const DeviceMemory<float> &y, int incy,                      \
                  DeviceMemory<float> *a, int lda) override;                   \
  bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
                  double alpha, const DeviceMemory<double> &x, int incx,       \
                  const DeviceMemory<double> &y, int incy,                     \
                  DeviceMemory<double> *a, int lda) override;                  \
  /* BLAS TBMV: x = op(A)*x, A triangular band with k off-diagonals. */        \
  bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  uint64 k, const DeviceMemory<float> &a, int lda,             \
                  DeviceMemory<float> *x, int incx) override;                  \
  bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  uint64 k, const DeviceMemory<double> &a, int lda,            \
                  DeviceMemory<double> *x, int incx) override;                 \
  bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  uint64 k, const DeviceMemory<std::complex<float>> &a,        \
                  int lda, DeviceMemory<std::complex<float>> *x, int incx)     \
      override;                                                                \
  bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  uint64 k, const DeviceMemory<std::complex<double>> &a,       \
                  int lda, DeviceMemory<std::complex<double>> *x, int incx)    \
      override;                                                                \
  /* BLAS TBSV: solve op(A)*x = b, b given in x; A triangular band. */         \
  bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  uint64 k, const DeviceMemory<float> &a, int lda,             \
                  DeviceMemory<float> *x, int incx) override;                  \
  bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  uint64 k, const DeviceMemory<double> &a, int lda,            \
                  DeviceMemory<double> *x, int incx) override;                 \
  bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  uint64 k, const DeviceMemory<std::complex<float>> &a,        \
                  int lda, DeviceMemory<std::complex<float>> *x, int incx)     \
      override;                                                                \
  bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  uint64 k, const DeviceMemory<std::complex<double>> &a,       \
                  int lda, DeviceMemory<std::complex<double>> *x, int incx)    \
      override;                                                                \
  /* BLAS TPMV: x = op(A)*x, A triangular in packed storage (ap). */           \
  bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  const DeviceMemory<float> &ap, DeviceMemory<float> *x,       \
                  int incx) override;                                          \
  bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  const DeviceMemory<double> &ap, DeviceMemory<double> *x,     \
                  int incx) override;                                          \
  bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  const DeviceMemory<std::complex<float>> &ap,                 \
                  DeviceMemory<std::complex<float>> *x, int incx) override;    \
  bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  const DeviceMemory<std::complex<double>> &ap,                \
                  DeviceMemory<std::complex<double>> *x, int incx) override;   \
  /* BLAS TPSV: solve op(A)*x = b, b given in x; A triangular packed. */       \
  bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  const DeviceMemory<float> &ap, DeviceMemory<float> *x,       \
                  int incx) override;                                          \
  bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  const DeviceMemory<double> &ap, DeviceMemory<double> *x,     \
                  int incx) override;                                          \
  bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  const DeviceMemory<std::complex<float>> &ap,                 \
                  DeviceMemory<std::complex<float>> *x, int incx) override;    \
  bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  const DeviceMemory<std::complex<double>> &ap,                \
                  DeviceMemory<std::complex<double>> *x, int incx) override;   \
  /* BLAS TRMV: x = op(A)*x for an n x n triangular matrix A. */               \
  bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  const DeviceMemory<float> &a, int lda,                       \
                  DeviceMemory<float> *x, int incx) override;                  \
  bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  const DeviceMemory<double> &a, int lda,                      \
                  DeviceMemory<double> *x, int incx) override;                 \
  bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  const DeviceMemory<std::complex<float>> &a, int lda,         \
                  DeviceMemory<std::complex<float>> *x, int incx) override;    \
  bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  const DeviceMemory<std::complex<double>> &a, int lda,        \
                  DeviceMemory<std::complex<double>> *x, int incx) override;   \
  /* BLAS TRSV: solve op(A)*x = b, b given in x; A is triangular. */           \
  bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  const DeviceMemory<float> &a, int lda,                       \
                  DeviceMemory<float> *x, int incx) override;                  \
  bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  const DeviceMemory<double> &a, int lda,                      \
                  DeviceMemory<double> *x, int incx) override;                 \
  bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  const DeviceMemory<std::complex<float>> &a, int lda,         \
                  DeviceMemory<std::complex<float>> *x, int incx) override;    \
  bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
                  blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
                  const DeviceMemory<std::complex<double>> &a, int lda,        \
                  DeviceMemory<std::complex<double>> *x, int incx) override;   \
2008   bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
2009                   blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
2010                   float alpha, const DeviceMemory<Eigen::half> &a, int lda,    \
2011                   const DeviceMemory<Eigen::half> &b, int ldb, float beta,     \
2012                   DeviceMemory<Eigen::half> *c, int ldc) override;             \
2013   bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
2014                   blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
2015                   float alpha, const DeviceMemory<float> &a, int lda,          \
2016                   const DeviceMemory<float> &b, int ldb, float beta,           \
2017                   DeviceMemory<float> *c, int ldc) override;                   \
2018   bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
2019                   blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
2020                   double alpha, const DeviceMemory<double> &a, int lda,        \
2021                   const DeviceMemory<double> &b, int ldb, double beta,         \
2022                   DeviceMemory<double> *c, int ldc) override;                  \
2023   bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
2024                   blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
2025                   std::complex<float> alpha,                                   \
2026                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2027                   const DeviceMemory<std::complex<float>> &b, int ldb,         \
2028                   std::complex<float> beta,                                    \
2029                   DeviceMemory<std::complex<float>> *c, int ldc) override;     \
2030   bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
2031                   blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
2032                   std::complex<double> alpha,                                  \
2033                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2034                   const DeviceMemory<std::complex<double>> &b, int ldb,        \
2035                   std::complex<double> beta,                                   \
2036                   DeviceMemory<std::complex<double>> *c, int ldc) override;    \
2037   bool DoBlasGemmWithProfiling(                                                \
2038       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2039       uint64 m, uint64 n, uint64 k, float alpha,                               \
2040       const DeviceMemory<Eigen::half> &a, int lda,                             \
2041       const DeviceMemory<Eigen::half> &b, int ldb, float beta,                 \
2042       DeviceMemory<Eigen::half> *c, int ldc,                                   \
2043       blas::ProfileResult *output_profile_result) override;                    \
2044   bool DoBlasGemmWithProfiling(                                                \
2045       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2046       uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
2047       int lda, const DeviceMemory<float> &b, int ldb, float beta,              \
2048       DeviceMemory<float> *c, int ldc,                                         \
2049       blas::ProfileResult *output_profile_result) override;                    \
2050   bool DoBlasGemmWithProfiling(                                                \
2051       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2052       uint64 m, uint64 n, uint64 k, double alpha,                              \
2053       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,   \
2054       int ldb, double beta, DeviceMemory<double> *c, int ldc,                  \
2055       blas::ProfileResult *output_profile_result) override;                    \
2056   bool DoBlasGemmWithProfiling(                                                \
2057       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2058       uint64 m, uint64 n, uint64 k, std::complex<float> alpha,                 \
2059       const DeviceMemory<std::complex<float>> &a, int lda,                     \
2060       const DeviceMemory<std::complex<float>> &b, int ldb,                     \
2061       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
2062       blas::ProfileResult *output_profile_result) override;                    \
2063   bool DoBlasGemmWithProfiling(                                                \
2064       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2065       uint64 m, uint64 n, uint64 k, std::complex<double> alpha,                \
2066       const DeviceMemory<std::complex<double>> &a, int lda,                    \
2067       const DeviceMemory<std::complex<double>> &b, int ldb,                    \
2068       std::complex<double> beta, DeviceMemory<std::complex<double>> *c,        \
2069       int ldc, blas::ProfileResult *output_profile_result) override;           \
2070   bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \
2071       override;                                                                \
2072   bool DoBlasGemmWithAlgorithm(                                                \
2073       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2074       uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,      \
2075       const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b,       \
2076       int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c,      \
2077       int ldc, blas::ComputationType computation_type,                         \
2078       blas::AlgorithmType algorithm,                                           \
2079       blas::ProfileResult *output_profile_result) override;                    \
2080   bool DoBlasGemmWithAlgorithm(                                                \
2081       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2082       uint64 m, uint64 n, uint64 k,                                            \
2083       const HostOrDeviceScalar<Eigen::half> &alpha,                            \
2084       const DeviceMemory<Eigen::half> &a, int lda,                             \
2085       const DeviceMemory<Eigen::half> &b, int ldb,                             \
2086       const HostOrDeviceScalar<Eigen::half> &beta,                             \
2087       DeviceMemory<Eigen::half> *c, int ldc,                                   \
2088       blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
2089       blas::ProfileResult *output_profile_result) override;                    \
2090   bool DoBlasGemmWithAlgorithm(                                                \
2091       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2092       uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,    \
2093       const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,     \
2094       int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,  \
2095       int ldc, blas::ComputationType computation_type,                         \
2096       blas::AlgorithmType algorithm,                                           \
2097       blas::ProfileResult *output_profile_result) override;                    \
2098   bool DoBlasGemmWithAlgorithm(                                                \
2099       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2100       uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,   \
2101       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,   \
2102       int ldb, const HostOrDeviceScalar<double> &beta,                         \
2103       DeviceMemory<double> *c, int ldc,                                        \
2104       blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
2105       blas::ProfileResult *output_profile_result) override;                    \
2106   bool DoBlasGemmWithAlgorithm(                                                \
2107       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2108       uint64 m, uint64 n, uint64 k,                                            \
2109       const HostOrDeviceScalar<std::complex<float>> &alpha,                    \
2110       const DeviceMemory<std::complex<float>> &a, int lda,                     \
2111       const DeviceMemory<std::complex<float>> &b, int ldb,                     \
2112       const HostOrDeviceScalar<std::complex<float>> &beta,                     \
2113       DeviceMemory<std::complex<float>> *c, int ldc,                           \
2114       blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
2115       blas::ProfileResult *output_profile_result) override;                    \
2116   bool DoBlasGemmWithAlgorithm(                                                \
2117       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2118       uint64 m, uint64 n, uint64 k,                                            \
2119       const HostOrDeviceScalar<std::complex<double>> &alpha,                   \
2120       const DeviceMemory<std::complex<double>> &a, int lda,                    \
2121       const DeviceMemory<std::complex<double>> &b, int ldb,                    \
2122       const HostOrDeviceScalar<std::complex<double>> &beta,                    \
2123       DeviceMemory<std::complex<double>> *c, int ldc,                          \
2124       blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
2125       blas::ProfileResult *output_profile_result) override;                    \
2126   bool DoBlasGemmBatched(                                                      \
2127       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2128       uint64 m, uint64 n, uint64 k, float alpha,                               \
2129       const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,         \
2130       const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,         \
2131       float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,      \
2132       int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
2133   bool DoBlasGemmBatched(                                                      \
2134       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2135       uint64 m, uint64 n, uint64 k, float alpha,                               \
2136       const port::ArraySlice<DeviceMemory<float> *> &a, int lda,               \
2137       const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,   \
2138       const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,               \
2139       int batch_count, ScratchAllocator *scratch_allocator) override;          \
2140   bool DoBlasGemmBatched(                                                      \
2141       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2142       uint64 m, uint64 n, uint64 k, double alpha,                              \
2143       const port::ArraySlice<DeviceMemory<double> *> &a, int lda,              \
2144       const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta, \
2145       const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,              \
2146       int batch_count, ScratchAllocator *scratch_allocator) override;          \
2147   bool DoBlasGemmBatched(                                                      \
2148       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2149       uint64 m, uint64 n, uint64 k, std::complex<float> alpha,                 \
2150       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda, \
2151       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, \
2152       std::complex<float> beta,                                                \
2153       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, \
2154       int batch_count, ScratchAllocator *scratch_allocator) override;          \
2155   bool DoBlasGemmBatched(                                                      \
2156       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2157       uint64 m, uint64 n, uint64 k, std::complex<double> alpha,                \
2158       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a,         \
2159       int lda,                                                                 \
2160       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b,         \
2161       int ldb, std::complex<double> beta,                                      \
2162       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c,         \
2163       int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
2164   bool DoBlasGemmStridedBatched(                                               \
2165       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2166       uint64 m, uint64 n, uint64 k, float alpha,                               \
2167       const DeviceMemory<Eigen::half> &a, int lda, int64 stride_a,             \
2168       const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, float beta, \
2169       DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, int batch_count); \
2170   bool DoBlasGemmStridedBatched(                                               \
2171       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2172       uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
2173       int lda, int64 stride_a, const DeviceMemory<float> &b, int ldb,          \
2174       int64 stride_b, float beta, DeviceMemory<float> *c, int ldc,             \
2175       int64 stride_c, int batch_count);                                        \
2176   bool DoBlasGemmStridedBatched(                                               \
2177       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2178       uint64 m, uint64 n, uint64 k, double alpha,                              \
2179       const DeviceMemory<double> &a, int lda, int64 stride_a,                  \
2180       const DeviceMemory<double> &b, int ldb, int64 stride_b, double beta,     \
2181       DeviceMemory<double> *c, int ldc, int64 stride_c, int batch_count);      \
2182   bool DoBlasGemmStridedBatched(                                               \
2183       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2184       uint64 m, uint64 n, uint64 k, std::complex<float> alpha,                 \
2185       const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,     \
2186       const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,     \
2187       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
2188       int64 stride_c, int batch_count);                                        \
2189   bool DoBlasGemmStridedBatched(                                               \
2190       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2191       uint64 m, uint64 n, uint64 k, std::complex<double> alpha,                \
2192       const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,    \
2193       const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,    \
2194       std::complex<double> beta, DeviceMemory<std::complex<double>> *c,        \
2195       int ldc, int64 stride_c, int batch_count);                               \
2196   bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2197                   uint64 m, uint64 n, std::complex<float> alpha,               \
2198                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2199                   const DeviceMemory<std::complex<float>> &b, int ldb,         \
2200                   std::complex<float> beta,                                    \
2201                   DeviceMemory<std::complex<float>> *c, int ldc) override;     \
2202   bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2203                   uint64 m, uint64 n, std::complex<double> alpha,              \
2204                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2205                   const DeviceMemory<std::complex<double>> &b, int ldb,        \
2206                   std::complex<double> beta,                                   \
2207                   DeviceMemory<std::complex<double>> *c, int ldc) override;    \
2208   bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,                       \
2209                   blas::Transpose trans, uint64 n, uint64 k, float alpha,      \
2210                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2211                   float beta, DeviceMemory<std::complex<float>> *c, int ldc)   \
2212       override;                                                                \
2213   bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,                       \
2214                   blas::Transpose trans, uint64 n, uint64 k, double alpha,     \
2215                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2216                   double beta, DeviceMemory<std::complex<double>> *c, int ldc) \
2217       override;                                                                \
2218   bool DoBlasHer2k(                                                            \
2219       Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n,  \
2220       uint64 k, std::complex<float> alpha,                                     \
2221       const DeviceMemory<std::complex<float>> &a, int lda,                     \
2222       const DeviceMemory<std::complex<float>> &b, int ldb, float beta,         \
2223       DeviceMemory<std::complex<float>> *c, int ldc) override;                 \
2224   bool DoBlasHer2k(                                                            \
2225       Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n,  \
2226       uint64 k, std::complex<double> alpha,                                    \
2227       const DeviceMemory<std::complex<double>> &a, int lda,                    \
2228       const DeviceMemory<std::complex<double>> &b, int ldb, double beta,       \
2229       DeviceMemory<std::complex<double>> *c, int ldc) override;                \
2230   bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2231                   uint64 m, uint64 n, float alpha,                             \
2232                   const DeviceMemory<float> &a, int lda,                       \
2233                   const DeviceMemory<float> &b, int ldb, float beta,           \
2234                   DeviceMemory<float> *c, int ldc) override;                   \
2235   bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2236                   uint64 m, uint64 n, double alpha,                            \
2237                   const DeviceMemory<double> &a, int lda,                      \
2238                   const DeviceMemory<double> &b, int ldb, double beta,         \
2239                   DeviceMemory<double> *c, int ldc) override;                  \
2240   bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2241                   uint64 m, uint64 n, std::complex<float> alpha,               \
2242                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2243                   const DeviceMemory<std::complex<float>> &b, int ldb,         \
2244                   std::complex<float> beta,                                    \
2245                   DeviceMemory<std::complex<float>> *c, int ldc) override;     \
2246   bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2247                   uint64 m, uint64 n, std::complex<double> alpha,              \
2248                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2249                   const DeviceMemory<std::complex<double>> &b, int ldb,        \
2250                   std::complex<double> beta,                                   \
2251                   DeviceMemory<std::complex<double>> *c, int ldc) override;    \
2252   bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
2253                   blas::Transpose trans, uint64 n, uint64 k, float alpha,      \
2254                   const DeviceMemory<float> &a, int lda, float beta,           \
2255                   DeviceMemory<float> *c, int ldc) override;                   \
2256   bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
2257                   blas::Transpose trans, uint64 n, uint64 k, double alpha,     \
2258                   const DeviceMemory<double> &a, int lda, double beta,         \
2259                   DeviceMemory<double> *c, int ldc) override;                  \
2260   bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
2261                   blas::Transpose trans, uint64 n, uint64 k,                   \
2262                   std::complex<float> alpha,                                   \
2263                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2264                   std::complex<float> beta,                                    \
2265                   DeviceMemory<std::complex<float>> *c, int ldc) override;     \
2266   bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
2267                   blas::Transpose trans, uint64 n, uint64 k,                   \
2268                   std::complex<double> alpha,                                  \
2269                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2270                   std::complex<double> beta,                                   \
2271                   DeviceMemory<std::complex<double>> *c, int ldc) override;    \
2272   bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
2273                    blas::Transpose trans, uint64 n, uint64 k, float alpha,     \
2274                    const DeviceMemory<float> &a, int lda,                      \
2275                    const DeviceMemory<float> &b, int ldb, float beta,          \
2276                    DeviceMemory<float> *c, int ldc) override;                  \
2277   bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
2278                    blas::Transpose trans, uint64 n, uint64 k, double alpha,    \
2279                    const DeviceMemory<double> &a, int lda,                     \
2280                    const DeviceMemory<double> &b, int ldb, double beta,        \
2281                    DeviceMemory<double> *c, int ldc) override;                 \
2282   bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
2283                    blas::Transpose trans, uint64 n, uint64 k,                  \
2284                    std::complex<float> alpha,                                  \
2285                    const DeviceMemory<std::complex<float>> &a, int lda,        \
2286                    const DeviceMemory<std::complex<float>> &b, int ldb,        \
2287                    std::complex<float> beta,                                   \
2288                    DeviceMemory<std::complex<float>> *c, int ldc) override;    \
2289   bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
2290                    blas::Transpose trans, uint64 n, uint64 k,                  \
2291                    std::complex<double> alpha,                                 \
2292                    const DeviceMemory<std::complex<double>> &a, int lda,       \
2293                    const DeviceMemory<std::complex<double>> &b, int ldb,       \
2294                    std::complex<double> beta,                                  \
2295                    DeviceMemory<std::complex<double>> *c, int ldc) override;   \
2296   bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2297                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2298                   uint64 n, float alpha, const DeviceMemory<float> &a,         \
2299                   int lda, DeviceMemory<float> *b, int ldb) override;          \
2300   bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2301                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2302                   uint64 n, double alpha, const DeviceMemory<double> &a,       \
2303                   int lda, DeviceMemory<double> *b, int ldb) override;         \
2304   bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2305                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2306                   uint64 n, std::complex<float> alpha,                         \
2307                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2308                   DeviceMemory<std::complex<float>> *b, int ldb) override;     \
2309   bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2310                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2311                   uint64 n, std::complex<double> alpha,                        \
2312                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2313                   DeviceMemory<std::complex<double>> *b, int ldb) override;    \
2314   bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2315                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2316                   uint64 n, float alpha, const DeviceMemory<float> &a,         \
2317                   int lda, DeviceMemory<float> *b, int ldb) override;          \
2318   bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2319                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2320                   uint64 n, double alpha, const DeviceMemory<double> &a,       \
2321                   int lda, DeviceMemory<double> *b, int ldb) override;         \
2322   bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2323                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2324                   uint64 n, std::complex<float> alpha,                         \
2325                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2326                   DeviceMemory<std::complex<float>> *b, int ldb) override;     \
2327   bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2328                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2329                   uint64 n, std::complex<double> alpha,                        \
2330                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2331                   DeviceMemory<std::complex<double>> *b, int ldb) override;    \
2332   port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>                     \
2333   CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &params) override; \
2334   port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>   \
2335   GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,               \
2336                             size_t max_workspace_size,                         \
2337                             int max_algorithm_count) override;                 \
2338   bool DoBlasLtMatmul(                                                         \
2339       Stream *stream, const blas::IBlasLtMatmulPlan *plan,                     \
2340       const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,               \
2341       DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,                \
2342       DeviceMemoryBase c, ScratchAllocator *scratch_allocator,                 \
2343       const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias,    \
2344       blas::ProfileResult *output_profile_result) override;                    \
2345   port::Status GetVersion(std::string *version) override;
2346 
2347 }  // namespace blas
2348 }  // namespace stream_executor
2349 
2350 #endif  // TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
2351