1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Exposes the family of BLAS routines as pre-canned high performance calls for
17 // use in conjunction with the StreamExecutor abstraction.
18 //
19 // Note that this interface is optionally supported by platforms; see
20 // StreamExecutor::SupportsBlas() for details.
21 //
22 // This abstraction makes it simple to entrain BLAS operations on GPU data into
23 // a Stream -- users typically will not use this API directly, but will use the
24 // Stream builder methods to entrain these operations "under the hood". For
25 // example:
26 //
27 //  DeviceMemory<float> x = stream_exec->AllocateArray<float>(1024);
28 //  DeviceMemory<float> y = stream_exec->AllocateArray<float>(1024);
29 //  // ... populate x and y ...
30 //  Stream stream{stream_exec};
31 //  stream
32 //    .Init()
33 //    .ThenBlasAxpy(1024, 5.5, x, 1, &y, 1);
34 //  SE_CHECK_OK(stream.BlockHostUntilDone());
35 //
// By using stream operations in this manner the user can easily intermix custom
// kernel launches (via Stream::ThenLaunch()) with these pre-canned BLAS
// routines.
39 
40 #ifndef TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
41 #define TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
42 
#include <complex>
#include <iosfwd>
#include <limits>
#include <vector>

#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/platform/port.h"
49 
// Forward declaration only, so this header does not have to include Eigen;
// presumably referenced by half-precision declarations elsewhere in the file.
namespace Eigen { struct half; }  // namespace Eigen
53 
54 namespace stream_executor {
55 
// Forward declarations of StreamExecutor types used by the declarations in
// this header.
template <typename ElemT>
class DeviceMemory;
class ScratchAllocator;
class Stream;
61 
62 namespace blas {
63 
64 // Specifies whether the input matrix will be transposed or
65 // transposed+conjugated before any BLAS operations.
66 enum class Transpose { kNoTranspose, kTranspose, kConjugateTranspose };
67 
68 // Returns a name for t.
69 string TransposeString(Transpose t);
70 
71 // Specifies whether the upper or lower triangular part of a
72 // symmetric/Hermitian matrix is used.
73 enum class UpperLower { kUpper, kLower };
74 
75 // Returns a name for ul.
76 string UpperLowerString(UpperLower ul);
77 
78 // Specifies whether a matrix is unit triangular.
79 enum class Diagonal { kUnit, kNonUnit };
80 
81 // Returns a name for d.
82 string DiagonalString(Diagonal d);
83 
84 // Specifies whether a Hermitian matrix appears on the left or right in
85 // operation.
86 enum class Side { kLeft, kRight };
87 
88 // Returns a name for s.
89 string SideString(Side s);
90 
91 // Type with which intermediate computations of a blas routine are performed.
92 //
93 // Some blas calls can perform computations with a type that's different than
94 // the type of their inputs/outputs.  This lets you e.g. multiply two matricies
95 // of int8s using float32s to store the matmul's intermediate values.
96 enum class ComputationType {
97   kF16,         // 16-bit floating-point
98   kF32,         // 32-bit floating-point
99   kF64,         // 64-bit floating-point
100   kI32,         // 32-bit integer
101   kComplexF32,  // Complex number comprised of two f32s.
102   kComplexF64,  // Complex number comprised of two f64s.
103 };
104 
105 // Converts a ComputationType to a string.
106 string ComputationTypeString(ComputationType ty);
107 
108 std::ostream &operator<<(std::ostream &os, ComputationType ty);
109 
110 // Opaque identifier for an "algorithm" used by a blas routine.  This functions
111 // as a hint to the blas library.
112 typedef int64 AlgorithmType;
113 constexpr AlgorithmType kDefaultAlgorithm = -1;
114 constexpr AlgorithmType kDefaultBlasGemm = -2;
115 constexpr AlgorithmType kDefaultBlasGemv = -3;
116 constexpr AlgorithmType kNoAlgorithm = -4;
117 
118 // blas uses -1 to represent the default algorithm. This happens to match up
119 // with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast
120 // to convert from AlgorithmType to cublasGemmAlgo_t, and uses a static_assert
121 // to ensure that this assumption does not break.
122 // If another blas implementation uses a different value for the default
123 // algorithm, then it needs to convert kDefaultGemmAlgo to that value
124 // (e.g. via a function called ToWhateverGemmAlgo).
125 constexpr AlgorithmType kDefaultGemmAlgo = -1;
126 
127 // Describes the result of a performance experiment, usually timing the speed of
128 // a particular AlgorithmType.
129 //
130 // If the call we were benchmarking failed (a common occurrence; not all
131 // algorithms are valid for all calls), is_valid() will be false.
132 class ProfileResult {
133  public:
is_valid()134   bool is_valid() const { return is_valid_; }
set_is_valid(bool val)135   void set_is_valid(bool val) { is_valid_ = val; }
algorithm()136   AlgorithmType algorithm() const { return algorithm_; }
set_algorithm(AlgorithmType val)137   void set_algorithm(AlgorithmType val) { algorithm_ = val; }
elapsed_time_in_ms()138   float elapsed_time_in_ms() const { return elapsed_time_in_ms_; }
set_elapsed_time_in_ms(float val)139   void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; }
140 
141  private:
142   bool is_valid_ = false;
143   AlgorithmType algorithm_ = kDefaultAlgorithm;
144   float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
145 };
146 
147 class AlgorithmConfig {
148  public:
AlgorithmConfig()149   AlgorithmConfig() : algorithm_(kDefaultAlgorithm) {}
AlgorithmConfig(AlgorithmType algorithm)150   explicit AlgorithmConfig(AlgorithmType algorithm) : algorithm_(algorithm) {}
algorithm()151   AlgorithmType algorithm() const { return algorithm_; }
set_algorithm(AlgorithmType val)152   void set_algorithm(AlgorithmType val) { algorithm_ = val; }
153   bool operator==(const AlgorithmConfig &other) const {
154     return this->algorithm_ == other.algorithm_;
155   }
156   bool operator!=(const AlgorithmConfig &other) const {
157     return !(*this == other);
158   }
159   string ToString() const;
160 
161  private:
162   AlgorithmType algorithm_;
163 };
164 
// BLAS support interface -- this can be derived from a GPU executor when the
// underlying platform has a BLAS library implementation available. See
// StreamExecutor::AsBlas().
//
// Thread-hostile: CUDA associates a CUDA-context with a particular thread in
// the system. Any operation that a user attempts to perform by enqueueing BLAS
// operations on a thread not associated with the CUDA-context has unknown
// behavior at the current time; see b/13176597
173 class BlasSupport {
174  public:
~BlasSupport()175   virtual ~BlasSupport() {}
176 
  // Computes the sum of magnitudes of the vector elements:
  //
  //   result <- |Re x(1)| + |Im x(1)| + |Re x(2)| + |Im x(2)| + ...
  //             + |Re x(n)| + |Im x(n)|
  //
  // with n = elem_count. Im x(i) is 0 for the real types float/double. The
  // result is always real: complex<float> input yields a float result and
  // complex<double> input yields a double result.
  // NOTE(review): the bool return presumably reports whether the operation
  // was successfully enqueued; confirm against the implementations.
  virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *result) = 0;
  virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *result) = 0;
  virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<float> *result) = 0;
  virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<double> *result) = 0;
193 
  // Performs the BLAS update y <- alpha * x + y over elem_count elements.
  // alpha has the same scalar type as the vector elements. x is only read and
  // y is updated in place.
  virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *y, int incy) = 0;
  virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *y, int incy) = 0;
  virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
209 
  // Copies vector x into vector y (y <- x) over elem_count elements. x is only
  // read and y is overwritten.
  virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *y, int incy) = 0;
  virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *y, int incy) = 0;
  virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
223 
  // Computes the real dot product result <- x . y = sum_i x(i) * y(i) over
  // elem_count elements. Only float and double are provided here. Complex
  // vectors use DoBlasDotc (conjugated) or DoBlasDotu (unconjugated). The
  // scalar result is written to the device buffer *result.
  virtual bool DoBlasDot(Stream *stream, uint64 elem_count,
                         const DeviceMemory<float> &x, int incx,
                         const DeviceMemory<float> &y, int incy,
                         DeviceMemory<float> *result) = 0;
  virtual bool DoBlasDot(Stream *stream, uint64 elem_count,
                         const DeviceMemory<double> &x, int incx,
                         const DeviceMemory<double> &y, int incy,
                         DeviceMemory<double> *result) = 0;
233 
  // Computes the conjugated complex dot product
  //
  //   result <- conj(x) . y = sum_i conj(x(i)) * y(i)
  //
  // over elem_count elements. The scalar result is written to *result.
  virtual bool DoBlasDotc(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *result) = 0;
  virtual bool DoBlasDotc(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *result) = 0;
243 
  // Computes the unconjugated complex dot product
  //
  //   result <- x . y = sum_i x(i) * y(i)
  //
  // over elem_count elements. Unlike DoBlasDotc, x is not conjugated. The
  // scalar result is written to *result.
  virtual bool DoBlasDotu(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *result) = 0;
  virtual bool DoBlasDotu(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *result) = 0;
254 
  // Computes the Euclidean norm of a vector: result <- ||x||.
  // See the following link for more information on the Euclidean norm:
  // http://en.wikipedia.org/wiki/Norm_(mathematics)#Euclidean_norm
  // The result is real even for complex inputs: float for complex<float> and
  // double for complex<double>.
  virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *result) = 0;
  virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *result) = 0;
  virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<float> *result) = 0;
  virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<double> *result) = 0;
270 
  // Performs rotation of points in the plane, for i = 1..elem_count:
  //
  //   x(i) = c*x(i) + s*y(i)
  //   y(i) = c*y(i) - s*x(i)
  //
  // The right-hand sides presumably use the pre-rotation values of x(i) and
  // y(i), as in reference BLAS. c and s are real scalars in every overload,
  // including the complex-vector ones. x and y are both updated in place.
  virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
                         DeviceMemory<float> *x, int incx,
                         DeviceMemory<float> *y, int incy, float c,
                         float s) = 0;
  virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
                         DeviceMemory<double> *x, int incx,
                         DeviceMemory<double> *y, int incy, double c,
                         double s) = 0;
  virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
                         DeviceMemory<std::complex<float>> *x, int incx,
                         DeviceMemory<std::complex<float>> *y, int incy,
                         float c, float s) = 0;
  virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
                         DeviceMemory<std::complex<double>> *x, int incx,
                         DeviceMemory<std::complex<double>> *y, int incy,
                         double c, double s) = 0;
290 
  // Computes the parameters for a Givens rotation.
  // Given the Cartesian coordinates (a, b) of a point, these routines return
  // the parameters c, s, r, and z associated with the Givens rotation. The
  // parameters c and s define a unitary matrix such that:
  //
  //   |  c s |.| a | = | r |
  //   | -s c | | b |   | 0 |
  //
  // The parameter z is defined such that if |a| > |b|, z is s; otherwise if
  // c is not 0 z is 1/c; otherwise z is 1.
  //
  // c is real in every overload, while s has the element type of a and b.
  // NOTE(review): these signatures have no separate r/z outputs. As in
  // reference BLAS, r and z are presumably written back into *a and *b;
  // confirm against the implementations.
  virtual bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
                          DeviceMemory<float> *b, DeviceMemory<float> *c,
                          DeviceMemory<float> *s) = 0;
  virtual bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
                          DeviceMemory<double> *b, DeviceMemory<double> *c,
                          DeviceMemory<double> *s) = 0;
  virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
                          DeviceMemory<std::complex<float>> *b,
                          DeviceMemory<float> *c,
                          DeviceMemory<std::complex<float>> *s) = 0;
  virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
                          DeviceMemory<std::complex<double>> *b,
                          DeviceMemory<double> *c,
                          DeviceMemory<std::complex<double>> *s) = 0;
315 
  // Performs modified Givens rotation of points in the plane.
  // Given two vectors x and y, each vector element of these vectors is replaced
  // as follows:
  //
  //   | x(i) | =  H | x(i) |
  //   | y(i) |      | y(i) |
  //
  // for i=1 to n (n = elem_count), where H is a modified Givens transformation
  // matrix whose values are stored in the param[1] through param[4] array.
  // NOTE(review): in reference BLAS, param[0] is a flag selecting which
  // entries of H are implicit. The same layout presumably applies here;
  // confirm. See the reference BLAS srotm/drotm documentation for details.
  virtual bool DoBlasRotm(Stream *stream, uint64 elem_count,
                          DeviceMemory<float> *x, int incx,
                          DeviceMemory<float> *y, int incy,
                          const DeviceMemory<float> &param) = 0;
  virtual bool DoBlasRotm(Stream *stream, uint64 elem_count,
                          DeviceMemory<double> *x, int incx,
                          DeviceMemory<double> *y, int incy,
                          const DeviceMemory<double> &param) = 0;
334 
  // Computes the parameters for a modified Givens rotation.
  // Given Cartesian coordinates (x1, y1) of an input vector, these routines
  // compute the components of a modified Givens transformation matrix H that
  // zeros the y-component of the resulting vector:
  //
  //   | x1 | =  H | x1 * sqrt(d1) |
  //   |  0 |      | y1 * sqrt(d2) |
  //
  // d1, d2 and x1 are passed by pointer (modifiable). y1 is passed by const
  // reference and is only read. The components of H are written to *param,
  // presumably in the same layout DoBlasRotm consumes; confirm. See the
  // reference BLAS srotmg/drotmg documentation for details.
  virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
                           DeviceMemory<float> *d2, DeviceMemory<float> *x1,
                           const DeviceMemory<float> &y1,
                           DeviceMemory<float> *param) = 0;
  virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
                           DeviceMemory<double> *d2, DeviceMemory<double> *x1,
                           const DeviceMemory<double> &y1,
                           DeviceMemory<double> *param) = 0;
352 
  // Scales a vector in place: x <- alpha * x, over elem_count elements.
  // Complex vectors have two overloads each: a same-type complex alpha, and a
  // real alpha (float for complex<float>, double for complex<double>).
  virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
                          DeviceMemory<float> *x, int incx) = 0;
  virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
                          DeviceMemory<double> *x, int incx) = 0;
  virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
                          DeviceMemory<std::complex<float>> *x, int incx) = 0;
  virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
                          DeviceMemory<std::complex<double>> *x, int incx) = 0;
  virtual bool DoBlasScal(Stream *stream, uint64 elem_count,
                          std::complex<float> alpha,
                          DeviceMemory<std::complex<float>> *x, int incx) = 0;
  virtual bool DoBlasScal(Stream *stream, uint64 elem_count,
                          std::complex<double> alpha,
                          DeviceMemory<std::complex<double>> *x, int incx) = 0;
368 
  // Exchanges the contents of vectors x and y over elem_count elements. Both
  // vectors are modified.
  virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
                          DeviceMemory<float> *x, int incx,
                          DeviceMemory<float> *y, int incy) = 0;
  virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
                          DeviceMemory<double> *x, int incx,
                          DeviceMemory<double> *y, int incy) = 0;
  virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
                          DeviceMemory<std::complex<float>> *x, int incx,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
                          DeviceMemory<std::complex<double>> *x, int incx,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
382 
  // Finds the index of the element with maximum absolute value among the
  // first elem_count elements of x and writes it to the device buffer *result.
  // NOTE(review): the index base is not stated here. Reference BLAS and cuBLAS
  // return 1-based indices, and for complex inputs they measure magnitude as
  // |Re| + |Im|. Confirm for each implementation.
  virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
                           const DeviceMemory<float> &x, int incx,
                           DeviceMemory<int> *result) = 0;
  virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
                           const DeviceMemory<double> &x, int incx,
                           DeviceMemory<int> *result) = 0;
  virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           DeviceMemory<int> *result) = 0;
  virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<double>> &x,
                           int incx, DeviceMemory<int> *result) = 0;
396 
  // Finds the index of the element with minimum absolute value among the
  // first elem_count elements of x and writes it to the device buffer *result.
  // NOTE(review): the index base is not stated here. cuBLAS returns a 1-based
  // index and, for complex inputs, measures magnitude as |Re| + |Im|. Confirm
  // for each implementation.
  virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
                           const DeviceMemory<float> &x, int incx,
                           DeviceMemory<int> *result) = 0;
  virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
                           const DeviceMemory<double> &x, int incx,
                           DeviceMemory<int> *result) = 0;
  virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           DeviceMemory<int> *result) = 0;
  virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<double>> &x,
                           int incx, DeviceMemory<int> *result) = 0;
410 
  // Computes a matrix-vector product using a general band matrix:
  //
  //     y <- alpha * a * x + beta * y,
  // or
  //     y <- alpha * a' * x + beta * y,
  // or
  //     y <- alpha * conj(a') * x + beta * y,
  //
  // where trans selects the form and a' denotes the transpose of a.
  // alpha and beta are scalars; a is an m-by-n general band matrix, with kl
  // sub-diagonals and ku super-diagonals; x is a vector with
  // n(trans==kNoTranspose)/m(otherwise) elements;
  // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements.
  // NOTE(review): a is presumably in reference-BLAS band storage with
  // lda >= kl + ku + 1; confirm against the implementations.
  virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, uint64 kl, uint64 ku, float alpha,
                          const DeviceMemory<float> &a, int lda,
                          const DeviceMemory<float> &x, int incx, float beta,
                          DeviceMemory<float> *y, int incy) = 0;
  virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, uint64 kl, uint64 ku, double alpha,
                          const DeviceMemory<double> &a, int lda,
                          const DeviceMemory<double> &x, int incx, double beta,
                          DeviceMemory<double> *y, int incy) = 0;
  virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, uint64 kl, uint64 ku,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, uint64 kl, uint64 ku,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
447 
  // Computes a matrix-vector product using a general matrix.
  //
  //     y <- alpha * a * x + beta * y,
  // or
  //     y <- alpha * a' * x + beta * y,
  // or
  //     y <- alpha * conj(a') * x + beta * y,
  //
  // where trans selects the form and a' denotes the transpose of a.
  // alpha and beta are scalars; a is an m-by-n general matrix; x is a vector
  // with n(trans==kNoTranspose)/m(otherwise) elements;
  // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements.
  virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, float alpha, const DeviceMemory<float> &a,
                          int lda, const DeviceMemory<float> &x, int incx,
                          float beta, DeviceMemory<float> *y, int incy) = 0;
  virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, double alpha, const DeviceMemory<double> &a,
                          int lda, const DeviceMemory<double> &x, int incx,
                          double beta, DeviceMemory<double> *y, int incy) = 0;
  virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
                          uint64 n, std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
479 
  // Profiling variant of DoBlasGemv. It takes the same arguments, presumably
  // performs the same matrix-vector product, and also receives
  // output_profile_result, a ProfileResult for reporting the measurement.
  // NOTE(review): presumably the implementation fills in validity and elapsed
  // time on *output_profile_result; confirm against the implementations.
  virtual bool DoBlasGemvWithProfiling(
      Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
      const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
      int incx, float beta, DeviceMemory<float> *y, int incy,
      ProfileResult *output_profile_result) = 0;
  virtual bool DoBlasGemvWithProfiling(
      Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
      const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
      int incx, double beta, DeviceMemory<double> *y, int incy,
      ProfileResult *output_profile_result) = 0;
  virtual bool DoBlasGemvWithProfiling(
      Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
      std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
      int lda, const DeviceMemory<std::complex<float>> &x, int incx,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
      ProfileResult *output_profile_result) = 0;
  virtual bool DoBlasGemvWithProfiling(
      Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
      std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
      int lda, const DeviceMemory<std::complex<double>> &x, int incx,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
      int incy, ProfileResult *output_profile_result) = 0;
502 
  // Performs a rank-1 update of a general matrix.
  //
  //     a <- alpha * x * y' + a,
  //
  // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
  // an m-by-n general matrix; y' denotes the transpose of y. Only float and
  // double are provided. Complex matrices use DoBlasGerc or DoBlasGeru.
  virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
                         const DeviceMemory<float> &x, int incx,
                         const DeviceMemory<float> &y, int incy,
                         DeviceMemory<float> *a, int lda) = 0;
  virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
                         const DeviceMemory<double> &x, int incx,
                         const DeviceMemory<double> &y, int incy,
                         DeviceMemory<double> *a, int lda) = 0;
517 
  // Performs a rank-1 update (conjugated) of a general matrix.
  //
  //     a <- alpha * x * conj(y') + a,
  //
  // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
  // an m-by-n general matrix; conj(y') is the conjugate transpose of y.
  virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *a, int lda) = 0;
  virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *a, int lda) = 0;
534 
  // Performs a rank-1 update (unconjugated) of a general matrix.
  //
  //     a <- alpha * x * y' + a,
  //
  // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
  // an m-by-n general matrix; y' is the plain (unconjugated) transpose of y.
  virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *a, int lda) = 0;
  virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *a, int lda) = 0;
551 
  // Computes a matrix-vector product using a Hermitian band matrix.
  //
  //     y <- alpha * a * x + beta * y,
  //
  // alpha and beta are scalars; a is an n-by-n Hermitian band matrix, with k
  // super-diagonals; x and y are n-element vectors. uplo selects whether the
  // upper or lower triangular part of a is used.
  virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
570 
  // Computes a matrix-vector product using a Hermitian matrix.
  //
  //     y <- alpha * a * x + beta * y,
  //
  // alpha and beta are scalars; a is an n-by-n Hermitian matrix; x and y are
  // n-element vectors. uplo selects whether the upper or lower triangular part
  // of a is used.
  virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
589 
590   // Performs a rank-1 update of a Hermitian matrix.
591   //
592   //     a <- alpha * x * conj(x') + a,
593   //
594   // alpha is a scalar; x is an n-element vector; a is an n-by-n Hermitian
595   // matrix.
  // std::complex<float> overload. Note alpha is real (float) even though x and
  // a are complex; *a is updated in place.
  virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha,
                         const DeviceMemory<std::complex<float>> &x, int incx,
                         DeviceMemory<std::complex<float>> *a, int lda) = 0;
  // std::complex<double> overload; alpha is real (double).
  virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha,
                         const DeviceMemory<std::complex<double>> &x, int incx,
                         DeviceMemory<std::complex<double>> *a, int lda) = 0;
604 
605   // Performs a rank-2 update of a Hermitian matrix.
606   //
  //     a <- alpha * x * conj(y') + conj(alpha) * y * conj(x') + a,
608   //
609   // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian
610   // matrix.
  // std::complex<float> overload; alpha is complex. *a is updated in place.
  virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *a, int lda) = 0;
  // std::complex<double> overload.
  virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *a, int lda) = 0;
621 
622   // Computes a matrix-vector product using a Hermitian packed matrix.
623   //
624   //     y <- alpha * a * x + beta * y,
625   //
626   // alpha and beta are scalars; a is an n-by-n Hermitian matrix, supplied in
627   // packed form; x and y are n-element vectors.
  // std::complex<float> overload. ap is the packed matrix, so there is no
  // leading-dimension argument; *y is read (scaled by beta) and overwritten.
  virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &ap,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *y, int incy) = 0;
  // std::complex<double> overload.
  virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &ap,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *y, int incy) = 0;
640 
641   // Performs a rank-1 update of a Hermitian packed matrix.
642   //
643   //     a <- alpha * x * conj(x') + a,
644   //
645   // alpha is a scalar; x is an n-element vector; a is an n-by-n Hermitian
646   // matrix, supplied in packed form.
  // std::complex<float> overload. Note alpha is real (float); the packed
  // matrix *ap is updated in place.
  virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha,
                         const DeviceMemory<std::complex<float>> &x, int incx,
                         DeviceMemory<std::complex<float>> *ap) = 0;
  // std::complex<double> overload; alpha is real (double).
  virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha,
                         const DeviceMemory<std::complex<double>> &x, int incx,
                         DeviceMemory<std::complex<double>> *ap) = 0;
655 
656   // Performs a rank-2 update of a Hermitian packed matrix.
657   //
  //     a <- alpha * x * conj(y') + conj(alpha) * y * conj(x') + a,
659   //
660   // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian
661   // matrix, supplied in packed form.
  // std::complex<float> overload; alpha is complex. The packed matrix *ap is
  // updated in place.
  virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *ap) = 0;
  // std::complex<double> overload.
  virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *ap) = 0;
672 
673   // Computes a matrix-vector product using a symmetric band matrix.
674   //
675   //     y <- alpha * a * x + beta * y,
676   //
677   // alpha and beta are scalars; a is an n-by-n symmetric band matrix, with k
678   // super-diagonals; x and y are n-element vectors.
  // float overload; *y is read (scaled by beta) and overwritten.
  virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, float alpha, const DeviceMemory<float> &a,
                          int lda, const DeviceMemory<float> &x, int incx,
                          float beta, DeviceMemory<float> *y, int incy) = 0;
  // double overload.
  virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          uint64 k, double alpha, const DeviceMemory<double> &a,
                          int lda, const DeviceMemory<double> &x, int incx,
                          double beta, DeviceMemory<double> *y, int incy) = 0;
687 
688   // Computes a matrix-vector product using a symmetric packed matrix.
689   //
690   //     y <- alpha * a * x + beta * y,
691   //
692   // alpha and beta are scalars; a is an n-by-n symmetric matrix, supplied in
693   // packed form; x and y are n-element vectors.
  // float overload. ap is the packed matrix (no leading dimension); *y is read
  // (scaled by beta) and overwritten.
  virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha, const DeviceMemory<float> &ap,
                          const DeviceMemory<float> &x, int incx, float beta,
                          DeviceMemory<float> *y, int incy) = 0;
  // double overload.
  virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          double alpha, const DeviceMemory<double> &ap,
                          const DeviceMemory<double> &x, int incx, double beta,
                          DeviceMemory<double> *y, int incy) = 0;
702 
703   // Performs a rank-1 update of a symmetric packed matrix.
704   //
705   //     a <- alpha * x * x' + a,
706   //
707   // alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric
708   // matrix, supplied in packed form.
  // float overload; the packed matrix *ap is updated in place.
  virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha, const DeviceMemory<float> &x, int incx,
                         DeviceMemory<float> *ap) = 0;
  // double overload.
  virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha, const DeviceMemory<double> &x, int incx,
                         DeviceMemory<double> *ap) = 0;
715 
716   // Performs a rank-2 update of a symmetric packed matrix.
717   //
  //     a <- alpha * x * y' + alpha * y * x' + a,
719   //
720   // alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric
721   // matrix, supplied in packed form.
  // float overload; the packed matrix *ap is updated in place.
  virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha, const DeviceMemory<float> &x, int incx,
                          const DeviceMemory<float> &y, int incy,
                          DeviceMemory<float> *ap) = 0;
  // double overload.
  virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          double alpha, const DeviceMemory<double> &x, int incx,
                          const DeviceMemory<double> &y, int incy,
                          DeviceMemory<double> *ap) = 0;
730 
731   // Computes a matrix-vector product for a symmetric matrix.
732   //
733   //     y <- alpha * a * x + beta * y,
734   //
735   // alpha and beta are scalars; a is an n-by-n symmetric matrix; x and y are
736   // n-element vectors.
  // float overload; *y is read (scaled by beta) and overwritten.
  virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha, const DeviceMemory<float> &a, int lda,
                          const DeviceMemory<float> &x, int incx, float beta,
                          DeviceMemory<float> *y, int incy) = 0;
  // double overload.
  virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
                          double alpha, const DeviceMemory<double> &a, int lda,
                          const DeviceMemory<double> &x, int incx, double beta,
                          DeviceMemory<double> *y, int incy) = 0;
745 
746   // Performs a rank-1 update of a symmetric matrix.
747   //
748   //     a <- alpha * x * x' + a,
749   //
750   // alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric
751   // matrix.
  // float overload; *a is updated in place.
  virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         float alpha, const DeviceMemory<float> &x, int incx,
                         DeviceMemory<float> *a, int lda) = 0;
  // double overload.
  virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
                         double alpha, const DeviceMemory<double> &x, int incx,
                         DeviceMemory<double> *a, int lda) = 0;
758 
  // Performs a rank-2 update of a symmetric matrix.
  //
  //     a <- alpha * x * y' + alpha * y * x' + a,
762   //
763   // alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric
764   // matrix.
  // float overload; *a is updated in place.
  virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          float alpha, const DeviceMemory<float> &x, int incx,
                          const DeviceMemory<float> &y, int incy,
                          DeviceMemory<float> *a, int lda) = 0;
  // double overload.
  virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                          double alpha, const DeviceMemory<double> &x, int incx,
                          const DeviceMemory<double> &y, int incy,
                          DeviceMemory<double> *a, int lda) = 0;
773 
774   // Computes a matrix-vector product using a triangular band matrix.
775   //
776   //     x <- a * x,
777   // or
778   //     x <- a' * x,
779   // or
780   //     x <- conj(a') * x,
781   //
782   // a is an n-by-n unit, or non-unit, upper or lower triangular band matrix,
  // with k+1 diagonals; x is an n-element vector.
  // float overload. There is no alpha/beta; *x is overwritten with the product.
  virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *x, int incx) = 0;
  // double overload.
  virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *x, int incx) = 0;
  // std::complex<float> overload.
  virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<std::complex<float>> &a,
                          int lda, DeviceMemory<std::complex<float>> *x,
                          int incx) = 0;
  // std::complex<double> overload.
  virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<std::complex<double>> &a,
                          int lda, DeviceMemory<std::complex<double>> *x,
                          int incx) = 0;
802 
803   // Solves a system of linear equations whose coefficients are in a triangular
804   // band matrix as below:
805   //
806   //     a * x = b,
807   // or
808   //     a' * x = b,
809   // or
810   //     conj(a') * x = b,
811   //
812   // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
813   // lower triangular band matrix, with k+1 diagonals.
  // float overload. Only x is passed (no separate b). NOTE(review): this
  // presumably follows the BLAS ?tbsv convention, where *x holds b on entry
  // and receives the solution on exit. Confirm against implementations.
  virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *x, int incx) = 0;
  // double overload.
  virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *x, int incx) = 0;
  // std::complex<float> overload.
  virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<std::complex<float>> &a,
                          int lda, DeviceMemory<std::complex<float>> *x,
                          int incx) = 0;
  // std::complex<double> overload.
  virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          uint64 k, const DeviceMemory<std::complex<double>> &a,
                          int lda, DeviceMemory<std::complex<double>> *x,
                          int incx) = 0;
832 
833   // Computes a matrix-vector product using a triangular packed matrix.
834   //
835   //     x <- a * x,
836   // or
837   //     x <- a' * x,
838   // or
839   //     x <- conj(a') * x,
840   //
841   // a is an n-by-n unit, or non-unit, upper or lower triangular matrix,
  // supplied in packed form; x is an n-element vector.
  // float overload. ap is the packed matrix (no leading dimension); *x is
  // overwritten with the product.
  virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                          int incx) = 0;
  // double overload.
  virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &ap,
                          DeviceMemory<double> *x, int incx) = 0;
  // std::complex<float> overload.
  virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &ap,
                          DeviceMemory<std::complex<float>> *x, int incx) = 0;
  // std::complex<double> overload.
  virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &ap,
                          DeviceMemory<std::complex<double>> *x, int incx) = 0;
859 
860   // Solves a system of linear equations whose coefficients are in a triangular
861   // packed matrix as below:
862   //
863   //     a * x = b,
864   // or
865   //     a' * x = b,
866   // or
867   //     conj(a') * x = b,
868   //
869   // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
870   // lower triangular matrix, supplied in packed form.
  // float overload. ap is the packed matrix. NOTE(review): *x presumably
  // holds b on entry and receives the solution on exit (BLAS ?tpsv
  // convention). Confirm against implementations.
  virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                          int incx) = 0;
  // double overload.
  virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &ap,
                          DeviceMemory<double> *x, int incx) = 0;
  // std::complex<float> overload.
  virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &ap,
                          DeviceMemory<std::complex<float>> *x, int incx) = 0;
  // std::complex<double> overload.
  virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &ap,
                          DeviceMemory<std::complex<double>> *x, int incx) = 0;
887 
888   // Computes a matrix-vector product using a triangular matrix.
889   //
890   //     x <- a * x,
891   // or
892   //     x <- a' * x,
893   // or
894   //     x <- conj(a') * x,
895   //
  // a is an n-by-n unit, or non-unit, upper or lower triangular matrix; x is
  // an n-element vector.
  // float overload; *x is overwritten with the product.
  virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *x, int incx) = 0;
  // double overload.
  virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *x, int incx) = 0;
  // std::complex<float> overload.
  virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *x, int incx) = 0;
  // std::complex<double> overload.
  virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *x, int incx) = 0;
914 
915   // Solves a system of linear equations whose coefficients are in a triangular
916   // matrix as below:
917   //
918   //     a * x = b,
919   // or
920   //     a' * x = b,
921   // or
922   //     conj(a') * x = b,
923   //
924   // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
925   // lower triangular matrix.
  // float overload. NOTE(review): *x presumably holds b on entry and receives
  // the solution on exit (BLAS ?trsv convention). Confirm against
  // implementations.
  virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *x, int incx) = 0;
  // double overload.
  virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *x, int incx) = 0;
  // std::complex<float> overload.
  virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *x, int incx) = 0;
  // std::complex<double> overload.
  virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *x, int incx) = 0;
942 
943   // Computes a matrix-matrix product with general matrices:
944   //
945   //     c <- alpha * op(a) * op(b) + beta * c,
946   //
947   // op(X) is one of op(X) = X, or op(X) = X', or op(X) = conj(X'); alpha and
948   // beta are scalars; a, b, and c are matrices; op(a) is an m-by-k matrix;
949   // op(b) is a k-by-n matrix; c is an m-by-n matrix.
950   //
951   // Note: The half interface uses float precision internally; the version
952   // that uses half precision internally is not yet supported. There is no
953   // batched version of the half-precision interface.
  // Eigen::half overload. alpha and beta are passed as float, not
  // Eigen::half; *c is read (scaled by beta) and overwritten.
  virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                          float alpha, const DeviceMemory<Eigen::half> &a,
                          int lda, const DeviceMemory<Eigen::half> &b, int ldb,
                          float beta, DeviceMemory<Eigen::half> *c,
                          int ldc) = 0;
  // float overload.
  virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                          float alpha, const DeviceMemory<float> &a, int lda,
                          const DeviceMemory<float> &b, int ldb, float beta,
                          DeviceMemory<float> *c, int ldc) = 0;
  // double overload.
  virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                          double alpha, const DeviceMemory<double> &a, int lda,
                          const DeviceMemory<double> &b, int ldb, double beta,
                          DeviceMemory<double> *c, int ldc) = 0;
  // std::complex<float> overload.
  virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &b, int ldb,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *c, int ldc) = 0;
  // std::complex<double> overload.
  virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &b, int ldb,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *c, int ldc) = 0;
984 
  // Same parameters and element-type overloads as DoBlasGemm, plus a trailing
  // output_profile_result out-parameter.
  // NOTE(review): implementations presumably record timing/validity for this
  // call into *output_profile_result. Confirm the exact semantics, including
  // whether it may be null.
  //
  // Eigen::half overload; alpha and beta are float.
  virtual bool DoBlasGemmWithProfiling(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
      int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
      DeviceMemory<Eigen::half> *c, int ldc,
      ProfileResult *output_profile_result) = 0;
  // float overload.
  virtual bool DoBlasGemmWithProfiling(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
      const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
      int ldc, ProfileResult *output_profile_result) = 0;
  // double overload.
  virtual bool DoBlasGemmWithProfiling(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
      const DeviceMemory<double> &b, int ldb, double beta,
      DeviceMemory<double> *c, int ldc,
      ProfileResult *output_profile_result) = 0;
  // std::complex<float> overload.
  virtual bool DoBlasGemmWithProfiling(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, std::complex<float> alpha,
      const DeviceMemory<std::complex<float>> &a, int lda,
      const DeviceMemory<std::complex<float>> &b, int ldb,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
      ProfileResult *output_profile_result) = 0;
  // std::complex<double> overload.
  virtual bool DoBlasGemmWithProfiling(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, std::complex<double> alpha,
      const DeviceMemory<std::complex<double>> &a, int lda,
      const DeviceMemory<std::complex<double>> &b, int ldb,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
      ProfileResult *output_profile_result) = 0;
1016 
1017   // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm.
  // Writes the supported algorithm identifiers to *out_algorithms.
  // NOTE(review): whether the vector is cleared first, and the exact meaning
  // of the bool return (presumably "query supported/succeeded"), should be
  // confirmed against implementations.
  virtual bool GetBlasGemmAlgorithms(
      std::vector<AlgorithmType> *out_algorithms) = 0;
1020 
  // Like DoBlasGemm, but accepts an algorithm and a compute type.
1022   //
1023   // The compute type lets you say (e.g.) that the inputs and outputs are
1024   // Eigen::halfs, but you want the internal computations to be done with
1025   // float32 precision.
1026   //
  // Note the subtle difference in the version that accepts Eigen::half --
  // alpha and beta have type const HostOrDeviceScalar<Eigen::half>&, not
  // float as in the Eigen::half overload of DoBlasGemm.
1029   //
1030   // If output_profile_result is not null, a failure here does not put the
1031   // stream in a failure state.  Instead, success/failure is indicated by
1032   // output_profile_result->is_valid().  This lets you use this function for
1033   // choosing the best algorithm among many (some of which may fail) without
1034   // creating a new Stream for each attempt.
  // int8 overload: int8 inputs a and b, int32 output c; alpha and beta are int
  // scalars.
  virtual bool DoBlasGemmWithAlgorithm(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
      const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b,
      int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int32> *c,
      int ldc, ComputationType computation_type, AlgorithmType algorithm,
      ProfileResult *output_profile_result) = 0;
  // Eigen::half overload; alpha and beta are Eigen::half scalars here, not
  // float as in the Eigen::half overload of DoBlasGemm.
  virtual bool DoBlasGemmWithAlgorithm(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
      const DeviceMemory<Eigen::half> &a, int lda,
      const DeviceMemory<Eigen::half> &b, int ldb,
      const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
      int ldc, ComputationType computation_type, AlgorithmType algorithm,
      ProfileResult *output_profile_result) = 0;
  // float overload.
  virtual bool DoBlasGemmWithAlgorithm(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
      const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
      int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
      int ldc, ComputationType computation_type, AlgorithmType algorithm,
      ProfileResult *output_profile_result) = 0;
  // double overload.
  virtual bool DoBlasGemmWithAlgorithm(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
      const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
      int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
      int ldc, ComputationType computation_type, AlgorithmType algorithm,
      ProfileResult *output_profile_result) = 0;
  // std::complex<float> overload.
  virtual bool DoBlasGemmWithAlgorithm(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
      const DeviceMemory<std::complex<float>> &a, int lda,
      const DeviceMemory<std::complex<float>> &b, int ldb,
      const HostOrDeviceScalar<std::complex<float>> &beta,
      DeviceMemory<std::complex<float>> *c, int ldc,
      ComputationType computation_type, AlgorithmType algorithm,
      ProfileResult *output_profile_result) = 0;
  // std::complex<double> overload.
  virtual bool DoBlasGemmWithAlgorithm(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
      const DeviceMemory<std::complex<double>> &a, int lda,
      const DeviceMemory<std::complex<double>> &b, int ldb,
      const HostOrDeviceScalar<std::complex<double>> &beta,
      DeviceMemory<std::complex<double>> *c, int ldc,
      ComputationType computation_type, AlgorithmType algorithm,
      ProfileResult *output_profile_result) = 0;
1082 
  // Computes a batch of matrix-matrix products with general matrices.
  // This is a batched version of DoBlasGemm.
  // The batched GEMM computes the matrix product for each input/output triple
  // in a, b, and c, each of which holds batch_count DeviceMemory pointers.
  // Eigen::half overload; alpha and beta are float. a, b and c are slices of
  // DeviceMemory pointers, and the matrices pointed to by c are the outputs.
  // NOTE(review): scratch_allocator presumably supplies temporary device
  // memory for the batch. Confirm whether a null allocator is accepted.
  virtual bool DoBlasGemmBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, float alpha,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
      float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
      int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
  // float overload.
  virtual bool DoBlasGemmBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, float alpha,
      const port::ArraySlice<DeviceMemory<float> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,
      const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator) = 0;
  // double overload.
  virtual bool DoBlasGemmBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, double alpha,
      const port::ArraySlice<DeviceMemory<double> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta,
      const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator) = 0;
  // std::complex<float> overload.
  virtual bool DoBlasGemmBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, std::complex<float> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
      std::complex<float> beta,
      const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator) = 0;
  // std::complex<double> overload.
  virtual bool DoBlasGemmBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, std::complex<double> alpha,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
      std::complex<double> beta,
      const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator) = 0;
1124 
  // Batched gemm with strides instead of pointer arrays.
  //
  // Like DoBlasGemmBatched, but each of a, b and c is a single DeviceMemory
  // holding all batch_count matrices back to back. stride_a, stride_b and
  // stride_c give the offset between consecutive matrices within the
  // respective buffer (presumably measured in elements, as in cuBLAS
  // gemmStridedBatched -- confirm against the implementations).
  // The remaining parameters play the same roles as in DoBlasGemm.
  //
  // Unlike DoBlasGemmBatched, these overloads take no ScratchAllocator.
  //
  // Half-precision matrices; alpha and beta are float, not half.
  virtual bool DoBlasGemmStridedBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
      int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
      int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
      int64 stride_c, int batch_count) = 0;
  // Single-precision real.
  virtual bool DoBlasGemmStridedBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
      int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
      float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
      int batch_count) = 0;
  // Double-precision real.
  virtual bool DoBlasGemmStridedBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
      int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
      double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
      int batch_count) = 0;
  // Single-precision complex.
  virtual bool DoBlasGemmStridedBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, std::complex<float> alpha,
      const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
      const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
      int64 stride_c, int batch_count) = 0;
  // Double-precision complex.
  virtual bool DoBlasGemmStridedBatched(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, std::complex<double> alpha,
      const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
      const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
      int64 stride_c, int batch_count) = 0;
1158 
  // Computes a matrix-matrix product where one input matrix is Hermitian:
  //
  //     c <- alpha * a * b + beta * c,
  // or
  //     c <- alpha * b * a + beta * c,
  //
  // alpha and beta are scalars; a is a Hermitian matrix; b and c are m-by-n
  // matrices.
  //
  // c is passed as a mutable pointer and receives the result. side presumably
  // selects which of the two forms above is computed, and uplo which triangle
  // of a is referenced (reference BLAS ?hemm conventions -- assumed, not
  // visible here). Only complex overloads exist.
  virtual bool DoBlasHemm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &b, int ldb,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *c, int ldc) = 0;
  virtual bool DoBlasHemm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &b, int ldb,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1181 
  // Performs a Hermitian rank-k update.
  //
  //     c <- alpha * a * conj(a') + beta * c,
  // or
  //     c <- alpha * conj(a') * a + beta * c,
  //
  // alpha and beta are scalars; c is an n-by-n Hermitian matrix; a is an n-by-k
  // matrix in the first case and a k-by-n matrix in the second case.
  //
  // Note that alpha and beta are real (float/double) even though a and c are
  // complex. trans presumably selects which of the two forms is computed, and
  // uplo which triangle of c is referenced (reference BLAS ?herk conventions --
  // assumed). c is passed as a mutable pointer and receives the result.
  virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          float alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          float beta, DeviceMemory<std::complex<float>> *c,
                          int ldc) = 0;
  virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          double alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          double beta, DeviceMemory<std::complex<double>> *c,
                          int ldc) = 0;
1202 
  // Performs a Hermitian rank-2k update.
  //
  //     c <- alpha * a * conj(b') + conj(alpha) * b * conj(a') + beta * c,
  // or
  //     c <- alpha * conj(b') * a + conj(alpha) * conj(a') * b + beta * c,
  //
  // alpha and beta are scalars; c is an n-by-n Hermitian matrix; a and b are
  // n-by-k matrices in the first case and k-by-n matrices in the second case.
  //
  // Note the mixed scalar types: alpha is complex (it is conjugated in the
  // second term), while beta is real (float/double). trans presumably selects
  // which of the two forms is computed, and uplo which triangle of c is
  // referenced (reference BLAS ?her2k conventions -- assumed). c is passed as a
  // mutable pointer and receives the result.
  virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           std::complex<float> alpha,
                           const DeviceMemory<std::complex<float>> &a, int lda,
                           const DeviceMemory<std::complex<float>> &b, int ldb,
                           float beta, DeviceMemory<std::complex<float>> *c,
                           int ldc) = 0;
  virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           std::complex<double> alpha,
                           const DeviceMemory<std::complex<double>> &a, int lda,
                           const DeviceMemory<std::complex<double>> &b, int ldb,
                           double beta, DeviceMemory<std::complex<double>> *c,
                           int ldc) = 0;
1225 
  // Computes a matrix-matrix product where one input matrix is symmetric.
  //
  //     c <- alpha * a * b + beta * c,
  // or
  //     c <- alpha * b * a + beta * c,
  //
  // alpha and beta are scalars; a is a symmetric matrix; b and c are m-by-n
  // matrices.
  //
  // c is passed as a mutable pointer and receives the result. side presumably
  // selects which of the two forms above is computed, and uplo which triangle
  // of a is referenced (reference BLAS ?symm conventions -- assumed, not
  // visible here). Overloads cover float, double, and both complex types; for
  // complex types a is symmetric (not Hermitian -- see DoBlasHemm for that).
  virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          float alpha, const DeviceMemory<float> &a, int lda,
                          const DeviceMemory<float> &b, int ldb, float beta,
                          DeviceMemory<float> *c, int ldc) = 0;
  virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          double alpha, const DeviceMemory<double> &a, int lda,
                          const DeviceMemory<double> &b, int ldb, double beta,
                          DeviceMemory<double> *c, int ldc) = 0;
  virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &b, int ldb,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *c, int ldc) = 0;
  virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &b, int ldb,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1258 
  // Performs a symmetric rank-k update.
  //
  //     c <- alpha * a * a' + beta * c,
  // or
  //     c <- alpha * a' * a + beta * c,
  //
  // alpha and beta are scalars; c is an n-by-n symmetric matrix; a is an n-by-k
  // matrix in the first case and a k-by-n matrix in the second case.
  //
  // Unlike DoBlasHerk, the formulas use a' without conjugation, and alpha/beta
  // have the same type as the matrix elements (complex for the complex
  // overloads). trans presumably selects which of the two forms is computed,
  // and uplo which triangle of c is referenced (reference BLAS ?syrk
  // conventions -- assumed). c is passed as a mutable pointer and receives the
  // result.
  virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          float alpha, const DeviceMemory<float> &a, int lda,
                          float beta, DeviceMemory<float> *c, int ldc) = 0;
  virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          double alpha, const DeviceMemory<double> &a, int lda,
                          double beta, DeviceMemory<double> *c, int ldc) = 0;
  virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *c, int ldc) = 0;
  virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1287 
  // Performs a symmetric rank-2k update.
  //
  //     c <- alpha * a * b' + alpha * b * a' + beta * c,
  // or
  //     c <- alpha * b' * a + alpha * a' * b + beta * c,
  //
  // alpha and beta are scalars; c is an n-by-n symmetric matrix; a and b are
  // n-by-k matrices in the first case and k-by-n matrices in the second case.
  //
  // alpha and beta have the same type as the matrix elements. trans presumably
  // selects which of the two forms is computed, and uplo which triangle of c is
  // referenced (reference BLAS ?syr2k conventions -- assumed). c is passed as
  // a mutable pointer and receives the result.
  virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           float alpha, const DeviceMemory<float> &a, int lda,
                           const DeviceMemory<float> &b, int ldb, float beta,
                           DeviceMemory<float> *c, int ldc) = 0;
  virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           double alpha, const DeviceMemory<double> &a, int lda,
                           const DeviceMemory<double> &b, int ldb, double beta,
                           DeviceMemory<double> *c, int ldc) = 0;
  virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           std::complex<float> alpha,
                           const DeviceMemory<std::complex<float>> &a, int lda,
                           const DeviceMemory<std::complex<float>> &b, int ldb,
                           std::complex<float> beta,
                           DeviceMemory<std::complex<float>> *c, int ldc) = 0;
  virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           std::complex<double> alpha,
                           const DeviceMemory<std::complex<double>> &a, int lda,
                           const DeviceMemory<std::complex<double>> &b, int ldb,
                           std::complex<double> beta,
                           DeviceMemory<std::complex<double>> *c, int ldc) = 0;
1320 
  // Computes a matrix-matrix product where one input matrix is triangular.
  //
  //     b <- alpha * op(a) * b,
  // or
  //     b <- alpha * b * op(a)
  //
  // alpha is a scalar; b is an m-by-n matrix; a is a unit, or non-unit, upper
  // or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a', or
  // op(a) = conj(a').
  //
  // b is both input and output: it is passed as a mutable pointer and is
  // overwritten in place with the product (there is no separate c). Presumably
  // side selects which of the two forms is computed, uplo whether a is upper or
  // lower triangular, transa which op(a) is used, and diag whether a is unit
  // triangular (reference BLAS ?trmm conventions -- assumed).
  virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n, float alpha,
                          const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *b, int ldb) = 0;
  virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n, double alpha,
                          const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *b, int ldb) = 0;
  virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *b, int ldb) = 0;
  virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *b, int ldb) = 0;
1352 
  // Solves a triangular matrix equation.
  //
  //     op(a) * x = alpha * b,
  // or
  //     x * op(a) = alpha * b
  //
  // alpha is a scalar; x and b are m-by-n matrices; a is a unit, or non-unit,
  // upper or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a',
  // or op(a) = conj(a').
  //
  // There is no separate x parameter: b is passed as a mutable pointer and is
  // presumably overwritten with the solution x on completion (as in reference
  // BLAS ?trsm -- confirm). Presumably side selects which equation is solved,
  // uplo whether a is upper or lower triangular, transa which op(a) is used,
  // and diag whether a is unit triangular.
  virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n, float alpha,
                          const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *b, int ldb) = 0;
  virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n, double alpha,
                          const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *b, int ldb) = 0;
  virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *b, int ldb) = 0;
  virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *b, int ldb) = 0;
1384 
 protected:
  // The constructor is protected, so a BlasSupport can only be created as the
  // base subobject of a concrete subclass implementing the virtuals above.
  BlasSupport() {}

 private:
  // NOTE(review): presumably makes BlasSupport non-copyable and
  // non-assignable -- see the SE_DISALLOW_COPY_AND_ASSIGN definition.
  SE_DISALLOW_COPY_AND_ASSIGN(BlasSupport);
1390 };
1391 
1392 // Macro used to quickly declare overrides for abstract virtuals in the
1393 // BlasSupport base class.
1394 #define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES                  \
1395   bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1396                   const DeviceMemory<float> &x, int incx,                      \
1397                   DeviceMemory<float> *result) override;                       \
1398   bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1399                   const DeviceMemory<double> &x, int incx,                     \
1400                   DeviceMemory<double> *result) override;                      \
1401   bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1402                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1403                   DeviceMemory<float> *result) override;                       \
1404   bool DoBlasAsum(Stream *stream, uint64 elem_count,                           \
1405                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1406                   DeviceMemory<double> *result) override;                      \
1407   bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,              \
1408                   const DeviceMemory<float> &x, int incx,                      \
1409                   DeviceMemory<float> *y, int incy) override;                  \
1410   bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,             \
1411                   const DeviceMemory<double> &x, int incx,                     \
1412                   DeviceMemory<double> *y, int incy) override;                 \
1413   bool DoBlasAxpy(Stream *stream, uint64 elem_count,                           \
1414                   std::complex<float> alpha,                                   \
1415                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1416                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1417   bool DoBlasAxpy(Stream *stream, uint64 elem_count,                           \
1418                   std::complex<double> alpha,                                  \
1419                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1420                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1421   bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1422                   const DeviceMemory<float> &x, int incx,                      \
1423                   DeviceMemory<float> *y, int incy) override;                  \
1424   bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1425                   const DeviceMemory<double> &x, int incx,                     \
1426                   DeviceMemory<double> *y, int incy) override;                 \
1427   bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1428                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1429                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1430   bool DoBlasCopy(Stream *stream, uint64 elem_count,                           \
1431                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1432                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1433   bool DoBlasDot(Stream *stream, uint64 elem_count,                            \
1434                  const DeviceMemory<float> &x, int incx,                       \
1435                  const DeviceMemory<float> &y, int incy,                       \
1436                  DeviceMemory<float> *result) override;                        \
1437   bool DoBlasDot(Stream *stream, uint64 elem_count,                            \
1438                  const DeviceMemory<double> &x, int incx,                      \
1439                  const DeviceMemory<double> &y, int incy,                      \
1440                  DeviceMemory<double> *result) override;                       \
1441   bool DoBlasDotc(Stream *stream, uint64 elem_count,                           \
1442                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1443                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1444                   DeviceMemory<std::complex<float>> *result) override;         \
1445   bool DoBlasDotc(Stream *stream, uint64 elem_count,                           \
1446                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1447                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1448                   DeviceMemory<std::complex<double>> *result) override;        \
1449   bool DoBlasDotu(Stream *stream, uint64 elem_count,                           \
1450                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1451                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1452                   DeviceMemory<std::complex<float>> *result) override;         \
1453   bool DoBlasDotu(Stream *stream, uint64 elem_count,                           \
1454                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1455                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1456                   DeviceMemory<std::complex<double>> *result) override;        \
1457   bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1458                   const DeviceMemory<float> &x, int incx,                      \
1459                   DeviceMemory<float> *result) override;                       \
1460   bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1461                   const DeviceMemory<double> &x, int incx,                     \
1462                   DeviceMemory<double> *result) override;                      \
1463   bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1464                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1465                   DeviceMemory<float> *result) override;                       \
1466   bool DoBlasNrm2(Stream *stream, uint64 elem_count,                           \
1467                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1468                   DeviceMemory<double> *result) override;                      \
1469   bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<float> *x,    \
1470                  int incx, DeviceMemory<float> *y, int incy, float c, float s) \
1471       override;                                                                \
1472   bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<double> *x,   \
1473                  int incx, DeviceMemory<double> *y, int incy, double c,        \
1474                  double s) override;                                           \
1475   bool DoBlasRot(Stream *stream, uint64 elem_count,                            \
1476                  DeviceMemory<std::complex<float>> *x, int incx,               \
1477                  DeviceMemory<std::complex<float>> *y, int incy, float c,      \
1478                  float s) override;                                            \
1479   bool DoBlasRot(Stream *stream, uint64 elem_count,                            \
1480                  DeviceMemory<std::complex<double>> *x, int incx,              \
1481                  DeviceMemory<std::complex<double>> *y, int incy, double c,    \
1482                  double s) override;                                           \
1483   bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a,                      \
1484                   DeviceMemory<float> *b, DeviceMemory<float> *c,              \
1485                   DeviceMemory<float> *s) override;                            \
1486   bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a,                     \
1487                   DeviceMemory<double> *b, DeviceMemory<double> *c,            \
1488                   DeviceMemory<double> *s) override;                           \
1489   bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,        \
1490                   DeviceMemory<std::complex<float>> *b,                        \
1491                   DeviceMemory<float> *c,                                      \
1492                   DeviceMemory<std::complex<float>> *s) override;              \
1493   bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,       \
1494                   DeviceMemory<std::complex<double>> *b,                       \
1495                   DeviceMemory<double> *c,                                     \
1496                   DeviceMemory<std::complex<double>> *s) override;             \
1497   bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<float> *x,   \
1498                   int incx, DeviceMemory<float> *y, int incy,                  \
1499                   const DeviceMemory<float> &param) override;                  \
1500   bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<double> *x,  \
1501                   int incx, DeviceMemory<double> *y, int incy,                 \
1502                   const DeviceMemory<double> &param) override;                 \
1503   bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,                    \
1504                    DeviceMemory<float> *d2, DeviceMemory<float> *x1,           \
1505                    const DeviceMemory<float> &y1, DeviceMemory<float> *param)  \
1506       override;                                                                \
1507   bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,                   \
1508                    DeviceMemory<double> *d2, DeviceMemory<double> *x1,         \
1509                    const DeviceMemory<double> &y1,                             \
1510                    DeviceMemory<double> *param) override;                      \
1511   bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,              \
1512                   DeviceMemory<float> *x, int incx) override;                  \
1513   bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,             \
1514                   DeviceMemory<double> *x, int incx) override;                 \
1515   bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,              \
1516                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
1517   bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,             \
1518                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
1519   bool DoBlasScal(Stream *stream, uint64 elem_count,                           \
1520                   std::complex<float> alpha,                                   \
1521                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
1522   bool DoBlasScal(Stream *stream, uint64 elem_count,                           \
1523                   std::complex<double> alpha,                                  \
1524                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
1525   bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<float> *x,   \
1526                   int incx, DeviceMemory<float> *y, int incy) override;        \
1527   bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<double> *x,  \
1528                   int incx, DeviceMemory<double> *y, int incy) override;       \
1529   bool DoBlasSwap(Stream *stream, uint64 elem_count,                           \
1530                   DeviceMemory<std::complex<float>> *x, int incx,              \
1531                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1532   bool DoBlasSwap(Stream *stream, uint64 elem_count,                           \
1533                   DeviceMemory<std::complex<double>> *x, int incx,             \
1534                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1535   bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
1536                    const DeviceMemory<float> &x, int incx,                     \
1537                    DeviceMemory<int> *result) override;                        \
1538   bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
1539                    const DeviceMemory<double> &x, int incx,                    \
1540                    DeviceMemory<int> *result) override;                        \
1541   bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
1542                    const DeviceMemory<std::complex<float>> &x, int incx,       \
1543                    DeviceMemory<int> *result) override;                        \
1544   bool DoBlasIamax(Stream *stream, uint64 elem_count,                          \
1545                    const DeviceMemory<std::complex<double>> &x, int incx,      \
1546                    DeviceMemory<int> *result) override;                        \
1547   bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
1548                    const DeviceMemory<float> &x, int incx,                     \
1549                    DeviceMemory<int> *result) override;                        \
1550   bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
1551                    const DeviceMemory<double> &x, int incx,                    \
1552                    DeviceMemory<int> *result) override;                        \
1553   bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
1554                    const DeviceMemory<std::complex<float>> &x, int incx,       \
1555                    DeviceMemory<int> *result) override;                        \
1556   bool DoBlasIamin(Stream *stream, uint64 elem_count,                          \
1557                    const DeviceMemory<std::complex<double>> &x, int incx,      \
1558                    DeviceMemory<int> *result) override;                        \
1559   bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1560                   uint64 kl, uint64 ku, float alpha,                           \
1561                   const DeviceMemory<float> &a, int lda,                       \
1562                   const DeviceMemory<float> &x, int incx, float beta,          \
1563                   DeviceMemory<float> *y, int incy) override;                  \
1564   bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1565                   uint64 kl, uint64 ku, double alpha,                          \
1566                   const DeviceMemory<double> &a, int lda,                      \
1567                   const DeviceMemory<double> &x, int incx, double beta,        \
1568                   DeviceMemory<double> *y, int incy) override;                 \
1569   bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1570                   uint64 kl, uint64 ku, std::complex<float> alpha,             \
1571                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1572                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1573                   std::complex<float> beta,                                    \
1574                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1575   bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1576                   uint64 kl, uint64 ku, std::complex<double> alpha,            \
1577                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1578                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1579                   std::complex<double> beta,                                   \
1580                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1581   bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1582                   float alpha, const DeviceMemory<float> &a, int lda,          \
1583                   const DeviceMemory<float> &x, int incx, float beta,          \
1584                   DeviceMemory<float> *y, int incy) override;                  \
1585   bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1586                   double alpha, const DeviceMemory<double> &a, int lda,        \
1587                   const DeviceMemory<double> &x, int incx, double beta,        \
1588                   DeviceMemory<double> *y, int incy) override;                 \
1589   bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1590                   std::complex<float> alpha,                                   \
1591                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1592                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1593                   std::complex<float> beta,                                    \
1594                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1595   bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n,   \
1596                   std::complex<double> alpha,                                  \
1597                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1598                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1599                   std::complex<double> beta,                                   \
1600                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1601   bool DoBlasGemvWithProfiling(                                                \
1602       Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,  \
1603       const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,     \
1604       int incx, float beta, DeviceMemory<float> *y, int incy,                  \
1605       blas::ProfileResult *output_profile_result) override;                    \
1606   bool DoBlasGemvWithProfiling(                                                \
1607       Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, \
1608       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,   \
1609       int incx, double beta, DeviceMemory<double> *y, int incy,                \
1610       blas::ProfileResult *output_profile_result) override;                    \
1611   bool DoBlasGemvWithProfiling(                                                \
1612       Stream *stream, blas::Transpose trans, uint64 m, uint64 n,               \
1613       std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,   \
1614       int lda, const DeviceMemory<std::complex<float>> &x, int incx,           \
1615       std::complex<float> beta, DeviceMemory<std::complex<float>> *y,          \
1616       int incy, blas::ProfileResult *output_profile_result) override;          \
1617   bool DoBlasGemvWithProfiling(                                                \
1618       Stream *stream, blas::Transpose trans, uint64 m, uint64 n,               \
1619       std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, \
1620       int lda, const DeviceMemory<std::complex<double>> &x, int incx,          \
1621       std::complex<double> beta, DeviceMemory<std::complex<double>> *y,        \
1622       int incy, blas::ProfileResult *output_profile_result) override;          \
1623   bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,              \
1624                  const DeviceMemory<float> &x, int incx,                       \
1625                  const DeviceMemory<float> &y, int incy,                       \
1626                  DeviceMemory<float> *a, int lda) override;                    \
1627   bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,             \
1628                  const DeviceMemory<double> &x, int incx,                      \
1629                  const DeviceMemory<double> &y, int incy,                      \
1630                  DeviceMemory<double> *a, int lda) override;                   \
1631   bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,                          \
1632                   std::complex<float> alpha,                                   \
1633                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1634                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1635                   DeviceMemory<std::complex<float>> *a, int lda) override;     \
1636   bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,                          \
1637                   std::complex<double> alpha,                                  \
1638                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1639                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1640                   DeviceMemory<std::complex<double>> *a, int lda) override;    \
1641   bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,                          \
1642                   std::complex<float> alpha,                                   \
1643                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1644                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1645                   DeviceMemory<std::complex<float>> *a, int lda) override;     \
1646   bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,                          \
1647                   std::complex<double> alpha,                                  \
1648                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1649                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1650                   DeviceMemory<std::complex<double>> *a, int lda) override;    \
1651   bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
1652                   std::complex<float> alpha,                                   \
1653                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1654                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1655                   std::complex<float> beta,                                    \
1656                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1657   bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
1658                   std::complex<double> alpha,                                  \
1659                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1660                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1661                   std::complex<double> beta,                                   \
1662                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1663   bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1664                   std::complex<float> alpha,                                   \
1665                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1666                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1667                   std::complex<float> beta,                                    \
1668                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1669   bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1670                   std::complex<double> alpha,                                  \
1671                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1672                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1673                   std::complex<double> beta,                                   \
1674                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1675   bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
1676                  const DeviceMemory<std::complex<float>> &x, int incx,         \
1677                  DeviceMemory<std::complex<float>> *a, int lda) override;      \
1678   bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,              \
1679                  double alpha, const DeviceMemory<std::complex<double>> &x,    \
1680                  int incx, DeviceMemory<std::complex<double>> *a, int lda)     \
1681       override;                                                                \
1682   bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1683                   std::complex<float> alpha,                                   \
1684                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1685                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1686                   DeviceMemory<std::complex<float>> *a, int lda) override;     \
1687   bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1688                   std::complex<double> alpha,                                  \
1689                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1690                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1691                   DeviceMemory<std::complex<double>> *a, int lda) override;    \
1692   bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1693                   std::complex<float> alpha,                                   \
1694                   const DeviceMemory<std::complex<float>> &ap,                 \
1695                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1696                   std::complex<float> beta,                                    \
1697                   DeviceMemory<std::complex<float>> *y, int incy) override;    \
1698   bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1699                   std::complex<double> alpha,                                  \
1700                   const DeviceMemory<std::complex<double>> &ap,                \
1701                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1702                   std::complex<double> beta,                                   \
1703                   DeviceMemory<std::complex<double>> *y, int incy) override;   \
1704   bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
1705                  const DeviceMemory<std::complex<float>> &x, int incx,         \
1706                  DeviceMemory<std::complex<float>> *ap) override;              \
1707   bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,              \
1708                  double alpha, const DeviceMemory<std::complex<double>> &x,    \
1709                  int incx, DeviceMemory<std::complex<double>> *ap) override;   \
1710   bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1711                   std::complex<float> alpha,                                   \
1712                   const DeviceMemory<std::complex<float>> &x, int incx,        \
1713                   const DeviceMemory<std::complex<float>> &y, int incy,        \
1714                   DeviceMemory<std::complex<float>> *ap) override;             \
1715   bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1716                   std::complex<double> alpha,                                  \
1717                   const DeviceMemory<std::complex<double>> &x, int incx,       \
1718                   const DeviceMemory<std::complex<double>> &y, int incy,       \
1719                   DeviceMemory<std::complex<double>> *ap) override;            \
1720   bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
1721                   float alpha, const DeviceMemory<float> &a, int lda,          \
1722                   const DeviceMemory<float> &x, int incx, float beta,          \
1723                   DeviceMemory<float> *y, int incy) override;                  \
1724   bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k,   \
1725                   double alpha, const DeviceMemory<double> &a, int lda,        \
1726                   const DeviceMemory<double> &x, int incx, double beta,        \
1727                   DeviceMemory<double> *y, int incy) override;                 \
1728   bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1729                   float alpha, const DeviceMemory<float> &ap,                  \
1730                   const DeviceMemory<float> &x, int incx, float beta,          \
1731                   DeviceMemory<float> *y, int incy) override;                  \
1732   bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1733                   double alpha, const DeviceMemory<double> &ap,                \
1734                   const DeviceMemory<double> &x, int incx, double beta,        \
1735                   DeviceMemory<double> *y, int incy) override;                 \
1736   bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
1737                  const DeviceMemory<float> &x, int incx,                       \
1738                  DeviceMemory<float> *ap) override;                            \
1739   bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,              \
1740                  double alpha, const DeviceMemory<double> &x, int incx,        \
1741                  DeviceMemory<double> *ap) override;                           \
1742   bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1743                   float alpha, const DeviceMemory<float> &x, int incx,         \
1744                   const DeviceMemory<float> &y, int incy,                      \
1745                   DeviceMemory<float> *ap) override;                           \
1746   bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1747                   double alpha, const DeviceMemory<double> &x, int incx,       \
1748                   const DeviceMemory<double> &y, int incy,                     \
1749                   DeviceMemory<double> *ap) override;                          \
1750   bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1751                   float alpha, const DeviceMemory<float> &a, int lda,          \
1752                   const DeviceMemory<float> &x, int incx, float beta,          \
1753                   DeviceMemory<float> *y, int incy) override;                  \
1754   bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1755                   double alpha, const DeviceMemory<double> &a, int lda,        \
1756                   const DeviceMemory<double> &x, int incx, double beta,        \
1757                   DeviceMemory<double> *y, int incy) override;                 \
1758   bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
1759                  const DeviceMemory<float> &x, int incx,                       \
1760                  DeviceMemory<float> *a, int lda) override;                    \
1761   bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,              \
1762                  double alpha, const DeviceMemory<double> &x, int incx,        \
1763                  DeviceMemory<double> *a, int lda) override;                   \
1764   bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1765                   float alpha, const DeviceMemory<float> &x, int incx,         \
1766                   const DeviceMemory<float> &y, int incy,                      \
1767                   DeviceMemory<float> *a, int lda) override;                   \
1768   bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,             \
1769                   double alpha, const DeviceMemory<double> &x, int incx,       \
1770                   const DeviceMemory<double> &y, int incy,                     \
1771                   DeviceMemory<double> *a, int lda) override;                  \
1772   bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
1773                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1774                   uint64 k, const DeviceMemory<float> &a, int lda,             \
1775                   DeviceMemory<float> *x, int incx) override;                  \
1776   bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
1777                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1778                   uint64 k, const DeviceMemory<double> &a, int lda,            \
1779                   DeviceMemory<double> *x, int incx) override;                 \
1780   bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
1781                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1782                   uint64 k, const DeviceMemory<std::complex<float>> &a,        \
1783                   int lda, DeviceMemory<std::complex<float>> *x, int incx)     \
1784       override;                                                                \
1785   bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,                       \
1786                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1787                   uint64 k, const DeviceMemory<std::complex<double>> &a,       \
1788                   int lda, DeviceMemory<std::complex<double>> *x, int incx)    \
1789       override;                                                                \
1790   bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
1791                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1792                   uint64 k, const DeviceMemory<float> &a, int lda,             \
1793                   DeviceMemory<float> *x, int incx) override;                  \
1794   bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
1795                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1796                   uint64 k, const DeviceMemory<double> &a, int lda,            \
1797                   DeviceMemory<double> *x, int incx) override;                 \
1798   bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
1799                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1800                   uint64 k, const DeviceMemory<std::complex<float>> &a,        \
1801                   int lda, DeviceMemory<std::complex<float>> *x, int incx)     \
1802       override;                                                                \
1803   bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,                       \
1804                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1805                   uint64 k, const DeviceMemory<std::complex<double>> &a,       \
1806                   int lda, DeviceMemory<std::complex<double>> *x, int incx)    \
1807       override;                                                                \
1808   bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
1809                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1810                   const DeviceMemory<float> &ap, DeviceMemory<float> *x,       \
1811                   int incx) override;                                          \
1812   bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
1813                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1814                   const DeviceMemory<double> &ap, DeviceMemory<double> *x,     \
1815                   int incx) override;                                          \
1816   bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
1817                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1818                   const DeviceMemory<std::complex<float>> &ap,                 \
1819                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
1820   bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,                       \
1821                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1822                   const DeviceMemory<std::complex<double>> &ap,                \
1823                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
1824   bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
1825                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1826                   const DeviceMemory<float> &ap, DeviceMemory<float> *x,       \
1827                   int incx) override;                                          \
1828   bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
1829                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1830                   const DeviceMemory<double> &ap, DeviceMemory<double> *x,     \
1831                   int incx) override;                                          \
1832   bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
1833                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1834                   const DeviceMemory<std::complex<float>> &ap,                 \
1835                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
1836   bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,                       \
1837                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1838                   const DeviceMemory<std::complex<double>> &ap,                \
1839                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
1840   bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
1841                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1842                   const DeviceMemory<float> &a, int lda,                       \
1843                   DeviceMemory<float> *x, int incx) override;                  \
1844   bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
1845                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1846                   const DeviceMemory<double> &a, int lda,                      \
1847                   DeviceMemory<double> *x, int incx) override;                 \
1848   bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
1849                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1850                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1851                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
1852   bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,                       \
1853                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1854                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1855                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
1856   bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
1857                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1858                   const DeviceMemory<float> &a, int lda,                       \
1859                   DeviceMemory<float> *x, int incx) override;                  \
1860   bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
1861                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1862                   const DeviceMemory<double> &a, int lda,                      \
1863                   DeviceMemory<double> *x, int incx) override;                 \
1864   bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
1865                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1866                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1867                   DeviceMemory<std::complex<float>> *x, int incx) override;    \
1868   bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,                       \
1869                   blas::Transpose trans, blas::Diagonal diag, uint64 n,        \
1870                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1871                   DeviceMemory<std::complex<double>> *x, int incx) override;   \
1872   bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
1873                   blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
1874                   float alpha, const DeviceMemory<Eigen::half> &a, int lda,    \
1875                   const DeviceMemory<Eigen::half> &b, int ldb, float beta,     \
1876                   DeviceMemory<Eigen::half> *c, int ldc) override;             \
1877   bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
1878                   blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
1879                   float alpha, const DeviceMemory<float> &a, int lda,          \
1880                   const DeviceMemory<float> &b, int ldb, float beta,           \
1881                   DeviceMemory<float> *c, int ldc) override;                   \
1882   bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
1883                   blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
1884                   double alpha, const DeviceMemory<double> &a, int lda,        \
1885                   const DeviceMemory<double> &b, int ldb, double beta,         \
1886                   DeviceMemory<double> *c, int ldc) override;                  \
1887   bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
1888                   blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
1889                   std::complex<float> alpha,                                   \
1890                   const DeviceMemory<std::complex<float>> &a, int lda,         \
1891                   const DeviceMemory<std::complex<float>> &b, int ldb,         \
1892                   std::complex<float> beta,                                    \
1893                   DeviceMemory<std::complex<float>> *c, int ldc) override;     \
1894   bool DoBlasGemm(Stream *stream, blas::Transpose transa,                      \
1895                   blas::Transpose transb, uint64 m, uint64 n, uint64 k,        \
1896                   std::complex<double> alpha,                                  \
1897                   const DeviceMemory<std::complex<double>> &a, int lda,        \
1898                   const DeviceMemory<std::complex<double>> &b, int ldb,        \
1899                   std::complex<double> beta,                                   \
1900                   DeviceMemory<std::complex<double>> *c, int ldc) override;    \
1901   bool DoBlasGemmWithProfiling(                                                \
1902       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1903       uint64 m, uint64 n, uint64 k, float alpha,                               \
1904       const DeviceMemory<Eigen::half> &a, int lda,                             \
1905       const DeviceMemory<Eigen::half> &b, int ldb, float beta,                 \
1906       DeviceMemory<Eigen::half> *c, int ldc,                                   \
1907       blas::ProfileResult *output_profile_result) override;                    \
1908   bool DoBlasGemmWithProfiling(                                                \
1909       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1910       uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
1911       int lda, const DeviceMemory<float> &b, int ldb, float beta,              \
1912       DeviceMemory<float> *c, int ldc,                                         \
1913       blas::ProfileResult *output_profile_result) override;                    \
1914   bool DoBlasGemmWithProfiling(                                                \
1915       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1916       uint64 m, uint64 n, uint64 k, double alpha,                              \
1917       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,   \
1918       int ldb, double beta, DeviceMemory<double> *c, int ldc,                  \
1919       blas::ProfileResult *output_profile_result) override;                    \
1920   bool DoBlasGemmWithProfiling(                                                \
1921       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1922       uint64 m, uint64 n, uint64 k, std::complex<float> alpha,                 \
1923       const DeviceMemory<std::complex<float>> &a, int lda,                     \
1924       const DeviceMemory<std::complex<float>> &b, int ldb,                     \
1925       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
1926       blas::ProfileResult *output_profile_result) override;                    \
1927   bool DoBlasGemmWithProfiling(                                                \
1928       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1929       uint64 m, uint64 n, uint64 k, std::complex<double> alpha,                \
1930       const DeviceMemory<std::complex<double>> &a, int lda,                    \
1931       const DeviceMemory<std::complex<double>> &b, int ldb,                    \
1932       std::complex<double> beta, DeviceMemory<std::complex<double>> *c,        \
1933       int ldc, blas::ProfileResult *output_profile_result) override;           \
1934   bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \
1935       override;                                                                \
1936   bool DoBlasGemmWithAlgorithm(                                                \
1937       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1938       uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,      \
1939       const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b,       \
1940       int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c,      \
1941       int ldc, blas::ComputationType computation_type,                         \
1942       blas::AlgorithmType algorithm,                                           \
1943       blas::ProfileResult *output_profile_result) override;                    \
1944   bool DoBlasGemmWithAlgorithm(                                                \
1945       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1946       uint64 m, uint64 n, uint64 k,                                            \
1947       const HostOrDeviceScalar<Eigen::half> &alpha,                            \
1948       const DeviceMemory<Eigen::half> &a, int lda,                             \
1949       const DeviceMemory<Eigen::half> &b, int ldb,                             \
1950       const HostOrDeviceScalar<Eigen::half> &beta,                             \
1951       DeviceMemory<Eigen::half> *c, int ldc,                                   \
1952       blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
1953       blas::ProfileResult *output_profile_result) override;                    \
1954   bool DoBlasGemmWithAlgorithm(                                                \
1955       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1956       uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,    \
1957       const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,     \
1958       int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,  \
1959       int ldc, blas::ComputationType computation_type,                         \
1960       blas::AlgorithmType algorithm,                                           \
1961       blas::ProfileResult *output_profile_result) override;                    \
1962   bool DoBlasGemmWithAlgorithm(                                                \
1963       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1964       uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,   \
1965       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,   \
1966       int ldb, const HostOrDeviceScalar<double> &beta,                         \
1967       DeviceMemory<double> *c, int ldc,                                        \
1968       blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
1969       blas::ProfileResult *output_profile_result) override;                    \
1970   bool DoBlasGemmWithAlgorithm(                                                \
1971       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1972       uint64 m, uint64 n, uint64 k,                                            \
1973       const HostOrDeviceScalar<std::complex<float>> &alpha,                    \
1974       const DeviceMemory<std::complex<float>> &a, int lda,                     \
1975       const DeviceMemory<std::complex<float>> &b, int ldb,                     \
1976       const HostOrDeviceScalar<std::complex<float>> &beta,                     \
1977       DeviceMemory<std::complex<float>> *c, int ldc,                           \
1978       blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
1979       blas::ProfileResult *output_profile_result) override;                    \
1980   bool DoBlasGemmWithAlgorithm(                                                \
1981       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1982       uint64 m, uint64 n, uint64 k,                                            \
1983       const HostOrDeviceScalar<std::complex<double>> &alpha,                   \
1984       const DeviceMemory<std::complex<double>> &a, int lda,                    \
1985       const DeviceMemory<std::complex<double>> &b, int ldb,                    \
1986       const HostOrDeviceScalar<std::complex<double>> &beta,                    \
1987       DeviceMemory<std::complex<double>> *c, int ldc,                          \
1988       blas::ComputationType computation_type, blas::AlgorithmType algorithm,   \
1989       blas::ProfileResult *output_profile_result) override;                    \
1990   bool DoBlasGemmBatched(                                                      \
1991       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1992       uint64 m, uint64 n, uint64 k, float alpha,                               \
1993       const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,         \
1994       const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,         \
1995       float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,      \
1996       int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
1997   bool DoBlasGemmBatched(                                                      \
1998       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
1999       uint64 m, uint64 n, uint64 k, float alpha,                               \
2000       const port::ArraySlice<DeviceMemory<float> *> &a, int lda,               \
2001       const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,   \
2002       const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,               \
2003       int batch_count, ScratchAllocator *scratch_allocator) override;          \
2004   bool DoBlasGemmBatched(                                                      \
2005       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2006       uint64 m, uint64 n, uint64 k, double alpha,                              \
2007       const port::ArraySlice<DeviceMemory<double> *> &a, int lda,              \
2008       const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta, \
2009       const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,              \
2010       int batch_count, ScratchAllocator *scratch_allocator) override;          \
2011   bool DoBlasGemmBatched(                                                      \
2012       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2013       uint64 m, uint64 n, uint64 k, std::complex<float> alpha,                 \
2014       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda, \
2015       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, \
2016       std::complex<float> beta,                                                \
2017       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, \
2018       int batch_count, ScratchAllocator *scratch_allocator) override;          \
2019   bool DoBlasGemmBatched(                                                      \
2020       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2021       uint64 m, uint64 n, uint64 k, std::complex<double> alpha,                \
2022       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a,         \
2023       int lda,                                                                 \
2024       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b,         \
2025       int ldb, std::complex<double> beta,                                      \
2026       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c,         \
2027       int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
2028   bool DoBlasGemmStridedBatched(                                               \
2029       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2030       uint64 m, uint64 n, uint64 k, float alpha,                               \
2031       const DeviceMemory<Eigen::half> &a, int lda, int64 stride_a,             \
2032       const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, float beta, \
2033       DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, int batch_count); \
2034   bool DoBlasGemmStridedBatched(                                               \
2035       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2036       uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
2037       int lda, int64 stride_a, const DeviceMemory<float> &b, int ldb,          \
2038       int64 stride_b, float beta, DeviceMemory<float> *c, int ldc,             \
2039       int64 stride_c, int batch_count);                                        \
2040   bool DoBlasGemmStridedBatched(                                               \
2041       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2042       uint64 m, uint64 n, uint64 k, double alpha,                              \
2043       const DeviceMemory<double> &a, int lda, int64 stride_a,                  \
2044       const DeviceMemory<double> &b, int ldb, int64 stride_b, double beta,     \
2045       DeviceMemory<double> *c, int ldc, int64 stride_c, int batch_count);      \
2046   bool DoBlasGemmStridedBatched(                                               \
2047       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2048       uint64 m, uint64 n, uint64 k, std::complex<float> alpha,                 \
2049       const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,     \
2050       const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,     \
2051       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
2052       int64 stride_c, int batch_count);                                        \
2053   bool DoBlasGemmStridedBatched(                                               \
2054       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
2055       uint64 m, uint64 n, uint64 k, std::complex<double> alpha,                \
2056       const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,    \
2057       const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,    \
2058       std::complex<double> beta, DeviceMemory<std::complex<double>> *c,        \
2059       int ldc, int64 stride_c, int batch_count);                               \
2060   bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2061                   uint64 m, uint64 n, std::complex<float> alpha,               \
2062                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2063                   const DeviceMemory<std::complex<float>> &b, int ldb,         \
2064                   std::complex<float> beta,                                    \
2065                   DeviceMemory<std::complex<float>> *c, int ldc) override;     \
2066   bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2067                   uint64 m, uint64 n, std::complex<double> alpha,              \
2068                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2069                   const DeviceMemory<std::complex<double>> &b, int ldb,        \
2070                   std::complex<double> beta,                                   \
2071                   DeviceMemory<std::complex<double>> *c, int ldc) override;    \
2072   bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,                       \
2073                   blas::Transpose trans, uint64 n, uint64 k, float alpha,      \
2074                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2075                   float beta, DeviceMemory<std::complex<float>> *c, int ldc)   \
2076       override;                                                                \
2077   bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,                       \
2078                   blas::Transpose trans, uint64 n, uint64 k, double alpha,     \
2079                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2080                   double beta, DeviceMemory<std::complex<double>> *c, int ldc) \
2081       override;                                                                \
2082   bool DoBlasHer2k(                                                            \
2083       Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n,  \
2084       uint64 k, std::complex<float> alpha,                                     \
2085       const DeviceMemory<std::complex<float>> &a, int lda,                     \
2086       const DeviceMemory<std::complex<float>> &b, int ldb, float beta,         \
2087       DeviceMemory<std::complex<float>> *c, int ldc) override;                 \
2088   bool DoBlasHer2k(                                                            \
2089       Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n,  \
2090       uint64 k, std::complex<double> alpha,                                    \
2091       const DeviceMemory<std::complex<double>> &a, int lda,                    \
2092       const DeviceMemory<std::complex<double>> &b, int ldb, double beta,       \
2093       DeviceMemory<std::complex<double>> *c, int ldc) override;                \
2094   bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2095                   uint64 m, uint64 n, float alpha,                             \
2096                   const DeviceMemory<float> &a, int lda,                       \
2097                   const DeviceMemory<float> &b, int ldb, float beta,           \
2098                   DeviceMemory<float> *c, int ldc) override;                   \
2099   bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2100                   uint64 m, uint64 n, double alpha,                            \
2101                   const DeviceMemory<double> &a, int lda,                      \
2102                   const DeviceMemory<double> &b, int ldb, double beta,         \
2103                   DeviceMemory<double> *c, int ldc) override;                  \
2104   bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2105                   uint64 m, uint64 n, std::complex<float> alpha,               \
2106                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2107                   const DeviceMemory<std::complex<float>> &b, int ldb,         \
2108                   std::complex<float> beta,                                    \
2109                   DeviceMemory<std::complex<float>> *c, int ldc) override;     \
2110   bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2111                   uint64 m, uint64 n, std::complex<double> alpha,              \
2112                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2113                   const DeviceMemory<std::complex<double>> &b, int ldb,        \
2114                   std::complex<double> beta,                                   \
2115                   DeviceMemory<std::complex<double>> *c, int ldc) override;    \
2116   bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
2117                   blas::Transpose trans, uint64 n, uint64 k, float alpha,      \
2118                   const DeviceMemory<float> &a, int lda, float beta,           \
2119                   DeviceMemory<float> *c, int ldc) override;                   \
2120   bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
2121                   blas::Transpose trans, uint64 n, uint64 k, double alpha,     \
2122                   const DeviceMemory<double> &a, int lda, double beta,         \
2123                   DeviceMemory<double> *c, int ldc) override;                  \
2124   bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
2125                   blas::Transpose trans, uint64 n, uint64 k,                   \
2126                   std::complex<float> alpha,                                   \
2127                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2128                   std::complex<float> beta,                                    \
2129                   DeviceMemory<std::complex<float>> *c, int ldc) override;     \
2130   bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,                       \
2131                   blas::Transpose trans, uint64 n, uint64 k,                   \
2132                   std::complex<double> alpha,                                  \
2133                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2134                   std::complex<double> beta,                                   \
2135                   DeviceMemory<std::complex<double>> *c, int ldc) override;    \
2136   bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
2137                    blas::Transpose trans, uint64 n, uint64 k, float alpha,     \
2138                    const DeviceMemory<float> &a, int lda,                      \
2139                    const DeviceMemory<float> &b, int ldb, float beta,          \
2140                    DeviceMemory<float> *c, int ldc) override;                  \
2141   bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
2142                    blas::Transpose trans, uint64 n, uint64 k, double alpha,    \
2143                    const DeviceMemory<double> &a, int lda,                     \
2144                    const DeviceMemory<double> &b, int ldb, double beta,        \
2145                    DeviceMemory<double> *c, int ldc) override;                 \
2146   bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
2147                    blas::Transpose trans, uint64 n, uint64 k,                  \
2148                    std::complex<float> alpha,                                  \
2149                    const DeviceMemory<std::complex<float>> &a, int lda,        \
2150                    const DeviceMemory<std::complex<float>> &b, int ldb,        \
2151                    std::complex<float> beta,                                   \
2152                    DeviceMemory<std::complex<float>> *c, int ldc) override;    \
2153   bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,                      \
2154                    blas::Transpose trans, uint64 n, uint64 k,                  \
2155                    std::complex<double> alpha,                                 \
2156                    const DeviceMemory<std::complex<double>> &a, int lda,       \
2157                    const DeviceMemory<std::complex<double>> &b, int ldb,       \
2158                    std::complex<double> beta,                                  \
2159                    DeviceMemory<std::complex<double>> *c, int ldc) override;   \
2160   bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2161                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2162                   uint64 n, float alpha, const DeviceMemory<float> &a,         \
2163                   int lda, DeviceMemory<float> *b, int ldb) override;          \
2164   bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2165                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2166                   uint64 n, double alpha, const DeviceMemory<double> &a,       \
2167                   int lda, DeviceMemory<double> *b, int ldb) override;         \
2168   bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2169                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2170                   uint64 n, std::complex<float> alpha,                         \
2171                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2172                   DeviceMemory<std::complex<float>> *b, int ldb) override;     \
2173   bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2174                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2175                   uint64 n, std::complex<double> alpha,                        \
2176                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2177                   DeviceMemory<std::complex<double>> *b, int ldb) override;    \
2178   bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2179                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2180                   uint64 n, float alpha, const DeviceMemory<float> &a,         \
2181                   int lda, DeviceMemory<float> *b, int ldb) override;          \
2182   bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2183                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2184                   uint64 n, double alpha, const DeviceMemory<double> &a,       \
2185                   int lda, DeviceMemory<double> *b, int ldb) override;         \
2186   bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2187                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2188                   uint64 n, std::complex<float> alpha,                         \
2189                   const DeviceMemory<std::complex<float>> &a, int lda,         \
2190                   DeviceMemory<std::complex<float>> *b, int ldb) override;     \
2191   bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo,      \
2192                   blas::Transpose transa, blas::Diagonal diag, uint64 m,       \
2193                   uint64 n, std::complex<double> alpha,                        \
2194                   const DeviceMemory<std::complex<double>> &a, int lda,        \
2195                   DeviceMemory<std::complex<double>> *b, int ldb) override;
2196 
2197 }  // namespace blas
2198 }  // namespace stream_executor
2199 
2200 #endif  // TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
2201