1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/stream.h"
17 
18 #include "tensorflow/stream_executor/platform/port.h"
19 
20 #include "absl/strings/str_cat.h"
21 #include "third_party/eigen3/Eigen/Core"
22 #include "tensorflow/stream_executor/blas.h"
23 #include "tensorflow/stream_executor/host_or_device_scalar.h"
24 #include "tensorflow/stream_executor/lib/stacktrace.h"
25 #include "tensorflow/stream_executor/platform.h"
26 #include "tensorflow/stream_executor/platform/logging.h"
27 #include "tensorflow/stream_executor/rng.h"
28 #include "tensorflow/stream_executor/stream_executor_internal.h"
29 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
30 
31 namespace stream_executor {
32 
33 namespace {
34 // Code to turn parameters to functions on stream into strings that
35 // will be VLOG'ed. We need overloads, instead of
36 // e.g. BatchDescriptorToVlogString(), as the code that calls these
37 // functions does not know what the type of the parameter is.
ToVlogString(const dnn::BatchDescriptor & descriptor)38 string ToVlogString(const dnn::BatchDescriptor &descriptor) {
39   return descriptor.ToShortString();
40 }
41 
ToVlogString(const dnn::FilterDescriptor & descriptor)42 string ToVlogString(const dnn::FilterDescriptor &descriptor) {
43   return descriptor.ToShortString();
44 }
45 
ToVlogString(const dnn::ConvolutionDescriptor & descriptor)46 string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
47   return descriptor.ToShortString();
48 }
49 
ToVlogString(const dnn::PoolingDescriptor & descriptor)50 string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
51   return descriptor.ToShortString();
52 }
53 
ToVlogString(const dnn::NormalizeDescriptor & descriptor)54 string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
55   return descriptor.ToShortString();
56 }
57 
ToVlogString(dnn::ActivationMode mode)58 string ToVlogString(dnn::ActivationMode mode) {
59   return dnn::ActivationModeString(mode);
60 }
61 
ToVlogString(const dnn::AlgorithmConfig & algo_config)62 string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
63   return algo_config.ToString();
64 }
65 
ToVlogString(dnn::ElementwiseOperation op)66 string ToVlogString(dnn::ElementwiseOperation op) {
67   return dnn::ElementwiseOperationString(op);
68 }
69 
ToVlogString(dnn::QuantizedActivationMode mode)70 string ToVlogString(dnn::QuantizedActivationMode mode) {
71   return dnn::QuantizedActivationModeString(mode);
72 }
73 
ToVlogString(blas::Transpose t)74 string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }
75 
ToVlogString(blas::UpperLower ul)76 string ToVlogString(blas::UpperLower ul) { return blas::UpperLowerString(ul); }
77 
ToVlogString(blas::Diagonal d)78 string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
79 
ToVlogString(blas::Side s)80 string ToVlogString(blas::Side s) { return blas::SideString(s); }
81 
ToVlogString(blas::ComputationType ty)82 string ToVlogString(blas::ComputationType ty) {
83   return blas::ComputationTypeString(ty);
84 }
85 
// Raw pointers are printed as "null" or as the address text that
// std::ostream produces (absl::StrCat has no pointer overload).
string ToVlogString(const void *ptr) {
  if (ptr != nullptr) {
    std::ostringstream address_text;
    address_text << ptr;
    return address_text.str();
  }
  return "null";
}
96 
// std::complex values are printed through std::ostream, which yields the
// "(real,imag)" form; absl::StrCat has no std::complex overload.
template <typename T>
string ToVlogString(const std::complex<T> &c) {
  std::ostringstream formatted;
  formatted << c;
  return formatted.str();
}
104 
// Callbacks cannot be printed meaningfully; only report whether one is set.
template <typename T>
string ToVlogString(const std::function<T> &f) {
  if (f) {
    return "<non-null function>";
  }
  return "null";
}
109 
ToVlogString(const DeviceMemoryBase & memory)110 string ToVlogString(const DeviceMemoryBase &memory) {
111   return ToVlogString(memory.opaque());
112 }
113 
ToVlogString(const DeviceMemoryBase * memory)114 string ToVlogString(const DeviceMemoryBase *memory) {
115   return memory == nullptr ? "null" : ToVlogString(*memory);
116 }
117 
ToVlogString(const Eigen::half & h)118 string ToVlogString(const Eigen::half &h) {
119   return absl::StrCat(static_cast<float>(h));
120 }
121 
ToVlogString(int i)122 string ToVlogString(int i) { return absl::StrCat(i); }
123 
ToVlogString(uint32 i)124 string ToVlogString(uint32 i) { return absl::StrCat(i); }
125 
ToVlogString(uint64 i)126 string ToVlogString(uint64 i) { return absl::StrCat(i); }
127 
ToVlogString(int64 i)128 string ToVlogString(int64 i) { return absl::StrCat(i); }
129 
ToVlogString(float f)130 string ToVlogString(float f) { return absl::StrCat(f); }
131 
ToVlogString(double d)132 string ToVlogString(double d) { return absl::StrCat(d); }
133 
134 template <typename T>
ToVlogString(const HostOrDeviceScalar<T> & memory_or_constant)135 string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) {
136   if (memory_or_constant.is_pointer()) {
137     return ToVlogString(memory_or_constant.pointer());
138   }
139   return ToVlogString(memory_or_constant.value());
140 }
141 
142 template <class T>
ToVlogString(port::ArraySlice<T> elements)143 string ToVlogString(port::ArraySlice<T> elements) {
144   string str = absl::StrCat(
145       ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
146       elements.size(), "]{");
147   const char *separator = "";
148   size_t max_to_show = std::numeric_limits<size_t>::max();
149   if (!VLOG_IS_ON(2)) {
150     max_to_show = 5;
151   } else if (!VLOG_IS_ON(3)) {
152     max_to_show = 20;
153   } else if (!VLOG_IS_ON(11)) {
154     max_to_show = 1000;
155   }
156   for (size_t i = 0; i < elements.size(); ++i) {
157     if (i == max_to_show) {
158       str += ", ...";
159       break;
160     }
161     absl::StrAppend(&str, separator, ToVlogString(elements[i]));
162     separator = ", ";
163   }
164   str += "}";
165   return str;
166 }
167 
168 template <class T>
ToVlogString(port::MutableArraySlice<T> elements)169 string ToVlogString(port::MutableArraySlice<T> elements) {
170   return ToVlogString(port::ArraySlice<T>(elements));
171 }
172 
ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout)173 string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
174   switch (depth_to_space_layout) {
175     case dnn::DepthToSpaceLayout::DepthHeightWidth:
176       return "DepthToSpaceLayout::DepthHeightWidth";
177   }
178   return "unknown DepthToSpaceLayout";
179 }
180 
ToVlogString(dnn::DataType data_type)181 string ToVlogString(dnn::DataType data_type) {
182   switch (data_type) {
183     case dnn::DataType::kFloat:
184       return "dnn::DataType::kFloat";
185     case dnn::DataType::kDouble:
186       return "dnn::DataType::kDouble";
187     case dnn::DataType::kHalf:
188       return "dnn::DataType::kHalf";
189     case dnn::DataType::kInt8:
190       return "dnn::DataType::kInt8";
191     case dnn::DataType::kInt32:
192       return "dnn::DataType::kInt32";
193     default:
194       return "unknown DataType";
195   }
196 }
197 
198 // Used together with PARAM to VLOG calls made to the stream. Intended
199 // to be used like this:
200 //
201 //   VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)});
202 //
203 // where a and b are the parameters to MyFunction.
204 //
205 // See VLOG_CALL for a short-hand for this. This way of doing it saves
206 // a tremendous amount of boilerplate code given how many functions
207 // there are on Stream and how many parameters they each have.
CallStr(const char * function_name,Stream * stream,std::vector<std::pair<const char *,string>> params)208 string CallStr(const char *function_name, Stream *stream,
209                std::vector<std::pair<const char *, string>> params) {
210   // Do not call this function unless VLOG is on since just
211   // constructing all the strings in params is expensive.
212   CHECK(VLOG_IS_ON(1));
213 
214   string str = absl::StrCat(stream->DebugStreamPointers(),
215                             " Called Stream::", function_name, "(");
216   const char *separator = "";
217   for (const auto &param : params) {
218     absl::StrAppend(&str, separator, param.first, "=", param.second);
219     separator = ", ";
220   }
221   absl::StrAppend(&str, ")");
222   if (VLOG_IS_ON(10)) {
223     absl::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
224   }
225   return str;
226 }
227 
// PARAM(x) expands to the {name, value} pair that CallStr() consumes. The
// name is the argument's source spelling, produced by the # stringizing
// operator, and the value is its ToVlogString() rendering. With PARAM, each
// logged argument only has to be written once.
#define PARAM(parameter) {#parameter, ToVlogString(parameter)}

// VLOG_CALL(...) logs a call to the enclosing Stream member function at
// VLOG level 1. It uses __func__ as the function name and `this` as the
// stream. Usage:
//
//   VLOG_CALL(PARAM(a), PARAM(b))
//
// which replaces the much longer hand-written equivalent
//
//   VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a)
//           << ", b=" << ToVlogString(b);
//
// Stream methods usually have many parameters with long names, so the
// savings are considerable.
#define VLOG_CALL(...) VLOG(1) << CallStr(__func__, this, {__VA_ARGS__})
246 
247 }  // namespace
248 
Stream(StreamExecutor * parent)249 Stream::Stream(StreamExecutor *parent)
250     : parent_(parent),
251       implementation_(parent->implementation()->GetStreamImplementation()),
252       allocated_(false),
253       ok_(false),
254       temporary_memory_manager_(this) {
255   VLOG_CALL(PARAM(parent));
256 }
257 
Stream(StreamExecutor * parent,internal::StreamInterface * implementation)258 Stream::Stream(StreamExecutor *parent,
259                internal::StreamInterface *implementation)
260     : parent_(parent),
261       implementation_(implementation),
262       allocated_(false),
263       ok_(false),
264       temporary_memory_manager_(this) {
265   VLOG_CALL(PARAM(parent), PARAM(implementation));
266 }
267 
~Stream()268 Stream::~Stream() {
269   VLOG_CALL();
270 
271   // Ensure the stream is completed.
272   auto status = BlockHostUntilDone();
273   if (!status.ok()) {
274     LOG(WARNING) << "Error blocking host until done in stream destructor: "
275                  << status;
276   }
277   temporary_memory_manager_.ForceDeallocateAll();
278 
279   if (allocated_) {
280     parent_->DeallocateStream(this);
281   }
282 }
283 
RefreshStatus()284 port::Status Stream::RefreshStatus() {
285   port::Status status = parent_->GetStatus(this);
286   CheckStatus(status);
287   return status;
288 }
289 
Init()290 Stream &Stream::Init() {
291   VLOG_CALL();
292 
293   mutex_lock lock(mu_);
294   CHECK_EQ(false, allocated_)
295       << "stream appears to already have been initialized";
296   CHECK(!ok_) << "stream should be in !ok() state pre-initialization";
297 
298   if (parent_->AllocateStream(this)) {
299     // Successful initialization!
300     allocated_ = true;
301     ok_ = true;
302   } else {
303     LOG(ERROR) << "failed to allocate stream during initialization";
304   }
305 
306   return *this;
307 }
308 
InitTimer(Timer * timer)309 Stream &Stream::InitTimer(Timer *timer) {
310   VLOG_CALL(PARAM(timer));
311 
312   if (ok()) {
313     CheckError(parent_->AllocateTimer(timer));
314   } else {
315     LOG(INFO) << "did not allocate timer: " << timer;
316   }
317   return *this;
318 }
319 
InitWithTimer(Timer * timer)320 Stream &Stream::InitWithTimer(Timer *timer) {
321   VLOG_CALL(PARAM(timer));
322 
323   return Init().InitTimer(timer);
324 }
325 
ThenRecordEvent(Event * event)326 Stream &Stream::ThenRecordEvent(Event *event) {
327   VLOG_CALL(PARAM(event));
328 
329   port::Status status = parent_->RecordEvent(this, event);
330   if (!status.ok()) {
331     LOG(ERROR) << "Error recording event in stream: " << status.error_message()
332                << "; not marking stream as bad, as the Event object may be "
333                << "at fault. Monitor for further errors.";
334   }
335 
336   return *this;
337 }
338 
ThenBatchNormalizationForward(const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<float> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,std::function<const DeviceMemory<float> & ()> var_to_inv_var,std::function<void ()> inv_var_to_var)339 Stream &Stream::ThenBatchNormalizationForward(
340     const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
341     const DeviceMemory<float> &offset,
342     const DeviceMemory<float> &estimated_mean,
343     const DeviceMemory<float> &estimated_variance,
344     const dnn::BatchDescriptor &x_desc,
345     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
346     DeviceMemory<float> *y, DeviceMemory<float> *batch_mean,
347     DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
348     DeviceMemory<float> *saved_inv_var, bool is_training,
349     std::function<const DeviceMemory<float> &()> var_to_inv_var,
350     std::function<void()> inv_var_to_var) {
351   VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
352             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
353   if (ok()) {
354     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
355       CheckError(dnn->DoBatchNormalizationForward(
356           this, x, scale, offset, estimated_mean, estimated_variance, x_desc,
357           scale_offset_desc, epsilon, y, batch_mean, batch_var, saved_mean,
358           saved_inv_var, is_training, std::move(var_to_inv_var),
359           std::move(inv_var_to_var)));
360     } else {
361       SetErrorAndLogNoDnnSupport();
362     }
363   }
364   return *this;
365 }
366 
ThenBatchNormalizationBackward(const DeviceMemory<float> & y_backprop,const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<float> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop)367 Stream &Stream::ThenBatchNormalizationBackward(
368     const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
369     const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
370     const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
371     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
372     DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
373     DeviceMemory<float> *offset_backprop) {
374   VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
375             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
376             PARAM(scale_backprop), PARAM(offset_backprop));
377   if (ok()) {
378     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
379       CheckError(dnn->DoBatchNormalizationBackward(
380           this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
381           epsilon, x_backprop, scale_backprop, offset_backprop));
382     } else {
383       SetErrorAndLogNoDnnSupport();
384     }
385   }
386   return *this;
387 }
388 
ThenBatchNormalizationForward(const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<Eigen::half> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,std::function<const DeviceMemory<float> & ()> var_to_inv_var,std::function<void ()> inv_var_to_var)389 Stream &Stream::ThenBatchNormalizationForward(
390     const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
391     const DeviceMemory<float> &offset,
392     const DeviceMemory<float> &estimated_mean,
393     const DeviceMemory<float> &estimated_variance,
394     const dnn::BatchDescriptor &x_desc,
395     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
396     DeviceMemory<Eigen::half> *y, DeviceMemory<float> *batch_mean,
397     DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
398     DeviceMemory<float> *saved_inv_var, bool is_training,
399     std::function<const DeviceMemory<float> &()> var_to_inv_var,
400     std::function<void()> inv_var_to_var) {
401   VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
402             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
403   if (ok()) {
404     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
405       CheckError(dnn->DoBatchNormalizationForward(
406           this, x, scale, offset, estimated_mean, estimated_variance, x_desc,
407           scale_offset_desc, epsilon, y, batch_mean, batch_var, saved_mean,
408           saved_inv_var, is_training, std::move(var_to_inv_var),
409           std::move(inv_var_to_var)));
410     } else {
411       SetErrorAndLogNoDnnSupport();
412     }
413   }
414   return *this;
415 }
416 
ThenBatchNormalizationBackward(const DeviceMemory<Eigen::half> & y_backprop,const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<Eigen::half> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop)417 Stream &Stream::ThenBatchNormalizationBackward(
418     const DeviceMemory<Eigen::half> &y_backprop,
419     const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
420     const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
421     const dnn::BatchDescriptor &x_desc,
422     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
423     DeviceMemory<Eigen::half> *x_backprop, DeviceMemory<float> *scale_backprop,
424     DeviceMemory<float> *offset_backprop) {
425   VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
426             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
427             PARAM(scale_backprop), PARAM(offset_backprop));
428   if (ok()) {
429     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
430       CheckError(dnn->DoBatchNormalizationBackward(
431           this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
432           epsilon, x_backprop, scale_backprop, offset_backprop));
433     } else {
434       SetErrorAndLogNoDnnSupport();
435     }
436   }
437   return *this;
438 }
439 
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<double> & conv_input_data,double conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<double> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<double> & side_input_data,double side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<double> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<double> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)440 Stream &Stream::ThenFusedConvolveWithAlgorithm(
441     const dnn::BatchDescriptor &conv_input_descriptor,
442     const DeviceMemory<double> &conv_input_data, double conv_input_scale,
443     const dnn::FilterDescriptor &filter_descriptor,
444     const DeviceMemory<double> &filter_data,
445     const dnn::ConvolutionDescriptor &convolution_descriptor,
446     const DeviceMemory<double> &side_input_data, double side_input_scale,
447     const dnn::BatchDescriptor &bias_descriptor,
448     const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
449     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
450     ScratchAllocator *scratch_allocator,
451     const dnn::AlgorithmConfig &algorithm_config,
452     dnn::ProfileResult *output_profile_result) {
453   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
454             PARAM(conv_input_scale), PARAM(filter_descriptor),
455             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
456             PARAM(side_input_data), PARAM(side_input_scale),
457             PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
458             PARAM(algorithm_config));
459 
460   if (ok()) {
461     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
462       auto status = dnn->DoFusedConvolve(
463           this, conv_input_descriptor, conv_input_data, conv_input_scale,
464           filter_descriptor, filter_data, convolution_descriptor,
465           side_input_data, side_input_scale, bias_descriptor, biases,
466           activation_mode, output_descriptor, output, scratch_allocator,
467           algorithm_config, output_profile_result);
468       if (!status && !output_profile_result) {
469         SetError();
470       }
471     } else {
472       SetErrorAndLogNoDnnSupport();
473     }
474   }
475   return *this;
476 }
477 
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<float> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<float> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)478 Stream &Stream::ThenFusedConvolveWithAlgorithm(
479     const dnn::BatchDescriptor &conv_input_descriptor,
480     const DeviceMemory<float> &conv_input_data, float conv_input_scale,
481     const dnn::FilterDescriptor &filter_descriptor,
482     const DeviceMemory<float> &filter_data,
483     const dnn::ConvolutionDescriptor &convolution_descriptor,
484     const DeviceMemory<float> &side_input_data, float side_input_scale,
485     const dnn::BatchDescriptor &bias_descriptor,
486     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
487     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
488     ScratchAllocator *scratch_allocator,
489     const dnn::AlgorithmConfig &algorithm_config,
490     dnn::ProfileResult *output_profile_result) {
491   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
492             PARAM(conv_input_scale), PARAM(filter_descriptor),
493             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
494             PARAM(side_input_data), PARAM(side_input_scale),
495             PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
496             PARAM(algorithm_config));
497 
498   if (ok()) {
499     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
500       auto status = dnn->DoFusedConvolve(
501           this, conv_input_descriptor, conv_input_data, conv_input_scale,
502           filter_descriptor, filter_data, convolution_descriptor,
503           side_input_data, side_input_scale, bias_descriptor, biases,
504           activation_mode, output_descriptor, output, scratch_allocator,
505           algorithm_config, output_profile_result);
506       if (!status && !output_profile_result) {
507         SetError();
508       }
509     } else {
510       SetErrorAndLogNoDnnSupport();
511     }
512   }
513   return *this;
514 }
515 
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<Eigen::half> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<Eigen::half> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<Eigen::half> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)516 Stream &Stream::ThenFusedConvolveWithAlgorithm(
517     const dnn::BatchDescriptor &conv_input_descriptor,
518     const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
519     const dnn::FilterDescriptor &filter_descriptor,
520     const DeviceMemory<Eigen::half> &filter_data,
521     const dnn::ConvolutionDescriptor &convolution_descriptor,
522     const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
523     const dnn::BatchDescriptor &bias_descriptor,
524     const DeviceMemory<Eigen::half> &biases,
525     dnn::ActivationMode activation_mode,
526     const dnn::BatchDescriptor &output_descriptor,
527     DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
528     const dnn::AlgorithmConfig &algorithm_config,
529     dnn::ProfileResult *output_profile_result) {
530   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
531             PARAM(conv_input_scale), PARAM(filter_descriptor),
532             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
533             PARAM(side_input_data), PARAM(side_input_scale),
534             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
535             PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
536 
537   if (ok()) {
538     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
539       auto status = dnn->DoFusedConvolve(
540           this, conv_input_descriptor, conv_input_data, conv_input_scale,
541           filter_descriptor, filter_data, convolution_descriptor,
542           side_input_data, side_input_scale, bias_descriptor, biases,
543           activation_mode, output_descriptor, output, scratch_allocator,
544           algorithm_config, output_profile_result);
545       if (!status && !output_profile_result) {
546         SetError();
547       }
548     } else {
549       SetErrorAndLogNoDnnSupport();
550     }
551   }
552   return *this;
553 }
554 
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<int8> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<int8> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<int8> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)555 Stream &Stream::ThenFusedConvolveWithAlgorithm(
556     const dnn::BatchDescriptor &conv_input_descriptor,
557     const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
558     const dnn::FilterDescriptor &filter_descriptor,
559     const DeviceMemory<int8> &filter_data,
560     const dnn::ConvolutionDescriptor &convolution_descriptor,
561     const DeviceMemory<int8> &side_input_data, float side_input_scale,
562     const dnn::BatchDescriptor &bias_descriptor,
563     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
564     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
565     ScratchAllocator *scratch_allocator,
566     const dnn::AlgorithmConfig &algorithm_config,
567     dnn::ProfileResult *output_profile_result) {
568   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
569             PARAM(conv_input_scale), PARAM(filter_descriptor),
570             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
571             PARAM(side_input_data), PARAM(side_input_scale),
572             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
573             PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
574 
575   if (ok()) {
576     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
577       auto status = dnn->DoFusedConvolve(
578           this, conv_input_descriptor, conv_input_data, conv_input_scale,
579           filter_descriptor, filter_data, convolution_descriptor,
580           side_input_data, side_input_scale, bias_descriptor, biases,
581           activation_mode, output_descriptor, output, scratch_allocator,
582           algorithm_config, output_profile_result);
583       if (!status && !output_profile_result) {
584         SetError();
585       }
586     } else {
587       SetErrorAndLogNoDnnSupport();
588     }
589   }
590   return *this;
591 }
592 
ThenConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<double> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<double> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<double> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)593 Stream &Stream::ThenConvolveWithAlgorithm(
594     const dnn::BatchDescriptor &input_descriptor,
595     const DeviceMemory<double> &input_data,
596     const dnn::FilterDescriptor &filter_descriptor,
597     const DeviceMemory<double> &filter_data,
598     const dnn::ConvolutionDescriptor &convolution_descriptor,
599     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
600     ScratchAllocator *scratch_allocator,
601     const dnn::AlgorithmConfig &algorithm_config,
602     dnn::ProfileResult *output_profile_result) {
603   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
604             PARAM(filter_descriptor), PARAM(filter_data),
605             PARAM(convolution_descriptor), PARAM(output_descriptor),
606             PARAM(output), PARAM(algorithm_config));
607 
608   if (ok()) {
609     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
610       DeviceMemory<uint8> scratch_memory;
611       dnn::AlgorithmDesc algorithm_desc;
612       auto status =
613           dnn->PrepareForConvolution(
614                  dnn::ConvolutionKind::FORWARD, this, input_descriptor,
615                  input_data, filter_descriptor, filter_data, output_descriptor,
616                  *output, convolution_descriptor, algorithm_config,
617                  scratch_allocator, &algorithm_desc, &scratch_memory)
618               .ok();
619       if (status) {
620         status = dnn->DoConvolve(
621             this, input_descriptor, input_data, filter_descriptor, filter_data,
622             convolution_descriptor, output_descriptor, output, algorithm_desc,
623             &scratch_memory, output_profile_result);
624       }
625       if (!status && !output_profile_result) {
626         SetError();
627       }
628     } else {
629       SetErrorAndLogNoDnnSupport();
630     }
631   }
632   return *this;
633 }
634 
ThenConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)635 Stream &Stream::ThenConvolveWithAlgorithm(
636     const dnn::BatchDescriptor &input_descriptor,
637     const DeviceMemory<float> &input_data,
638     const dnn::FilterDescriptor &filter_descriptor,
639     const DeviceMemory<float> &filter_data,
640     const dnn::ConvolutionDescriptor &convolution_descriptor,
641     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
642     ScratchAllocator *scratch_allocator,
643     const dnn::AlgorithmConfig &algorithm_config,
644     dnn::ProfileResult *output_profile_result) {
645   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
646             PARAM(filter_descriptor), PARAM(filter_data),
647             PARAM(convolution_descriptor), PARAM(output_descriptor),
648             PARAM(output), PARAM(algorithm_config));
649 
650   if (ok()) {
651     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
652       DeviceMemory<uint8> scratch_memory;
653       dnn::AlgorithmDesc algorithm_desc;
654       auto status =
655           dnn->PrepareForConvolution(
656                  dnn::ConvolutionKind::FORWARD, this, input_descriptor,
657                  input_data, filter_descriptor, filter_data, output_descriptor,
658                  *output, convolution_descriptor, algorithm_config,
659                  scratch_allocator, &algorithm_desc, &scratch_memory)
660               .ok();
661       if (status) {
662         status = dnn->DoConvolve(
663             this, input_descriptor, input_data, filter_descriptor, filter_data,
664             convolution_descriptor, output_descriptor, output, algorithm_desc,
665             &scratch_memory, output_profile_result);
666       }
667       if (!status && !output_profile_result) {
668         SetError();
669       }
670     } else {
671       SetErrorAndLogNoDnnSupport();
672     }
673   }
674   return *this;
675 }
676 
ThenConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)677 Stream &Stream::ThenConvolveWithAlgorithm(
678     const dnn::BatchDescriptor &input_descriptor,
679     const DeviceMemory<Eigen::half> &input_data,
680     const dnn::FilterDescriptor &filter_descriptor,
681     const DeviceMemory<Eigen::half> &filter_data,
682     const dnn::ConvolutionDescriptor &convolution_descriptor,
683     const dnn::BatchDescriptor &output_descriptor,
684     DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
685     const dnn::AlgorithmConfig &algorithm_config,
686     dnn::ProfileResult *output_profile_result) {
687   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
688             PARAM(filter_descriptor), PARAM(filter_data),
689             PARAM(convolution_descriptor), PARAM(output_descriptor),
690             PARAM(output), PARAM(algorithm_config));
691 
692   if (ok()) {
693     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
694       DeviceMemory<uint8> scratch_memory;
695       dnn::AlgorithmDesc algorithm_desc;
696       auto status =
697           dnn->PrepareForConvolution(
698                  dnn::ConvolutionKind::FORWARD, this, input_descriptor,
699                  input_data, filter_descriptor, filter_data, output_descriptor,
700                  *output, convolution_descriptor, algorithm_config,
701                  scratch_allocator, &algorithm_desc, &scratch_memory)
702               .ok();
703       if (status) {
704         status = dnn->DoConvolve(
705             this, input_descriptor, input_data, filter_descriptor, filter_data,
706             convolution_descriptor, output_descriptor, output, algorithm_desc,
707             &scratch_memory, output_profile_result);
708       }
709       if (!status && !output_profile_result) {
710         SetError();
711       }
712     } else {
713       SetErrorAndLogNoDnnSupport();
714     }
715   }
716   return *this;
717 }
718 
ThenConvolve(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)719 Stream &Stream::ThenConvolve(
720     const dnn::BatchDescriptor &input_descriptor,
721     const DeviceMemory<float> &input_data,
722     const dnn::FilterDescriptor &filter_descriptor,
723     const DeviceMemory<float> &filter_data,
724     const dnn::ConvolutionDescriptor &convolution_descriptor,
725     const dnn::BatchDescriptor &output_descriptor,
726     DeviceMemory<float> *output) {
727   return ThenConvolveWithAlgorithm(
728       input_descriptor, input_data, filter_descriptor, filter_data,
729       convolution_descriptor, output_descriptor, output,
730       /*scratch_allocator=*/nullptr, dnn::AlgorithmConfig(),
731       /*output_profile_result=*/nullptr);
732 }
733 
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)734 Stream &Stream::ThenConvolveQuantized(
735     const dnn::BatchDescriptor &input_descriptor,
736     const DeviceMemory<float> &input_data,
737     const dnn::FilterDescriptor &filter_descriptor,
738     const DeviceMemory<int8> &filter_coefficients,
739     const DeviceMemory<float> &coefficient_scales,
740     const dnn::ConvolutionDescriptor &convolution_descriptor,
741     const dnn::BatchDescriptor &output_descriptor,
742     DeviceMemory<float> *output) {
743   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
744             PARAM(filter_descriptor), PARAM(filter_coefficients),
745             PARAM(coefficient_scales), PARAM(convolution_descriptor),
746             PARAM(output_descriptor), PARAM(output));
747 
748   if (ok()) {
749     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
750       CheckError(dnn->DoConvolveQuantized(
751           this, input_descriptor, input_data, filter_descriptor,
752           filter_coefficients, coefficient_scales, convolution_descriptor,
753           output_descriptor, output));
754     } else {
755       SetError();
756       LOG(WARNING)
757           << "attempting to perform DNN operation using StreamExecutor "
758              "without DNN support";
759     }
760   }
761   return *this;
762 }
763 
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int16> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)764 Stream &Stream::ThenConvolveQuantized(
765     const dnn::BatchDescriptor &input_descriptor,
766     const DeviceMemory<float> &input_data,
767     const dnn::FilterDescriptor &filter_descriptor,
768     const DeviceMemory<int16> &filter_coefficients,
769     const DeviceMemory<float> &coefficient_scales,
770     const dnn::ConvolutionDescriptor &convolution_descriptor,
771     const dnn::BatchDescriptor &output_descriptor,
772     DeviceMemory<float> *output) {
773   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
774             PARAM(filter_descriptor), PARAM(filter_coefficients),
775             PARAM(coefficient_scales), PARAM(convolution_descriptor),
776             PARAM(output_descriptor), PARAM(output));
777 
778   if (ok()) {
779     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
780       CheckError(dnn->DoConvolveQuantized(
781           this, input_descriptor, input_data, filter_descriptor,
782           filter_coefficients, coefficient_scales, convolution_descriptor,
783           output_descriptor, output));
784     } else {
785       SetError();
786       LOG(WARNING)
787           << "attempting to perform DNN operation using StreamExecutor "
788              "without DNN support";
789     }
790   }
791   return *this;
792 }
793 
ThenSeparableConvolve(const dnn::BatchDescriptor & batch_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,int depth_multiplier,const DeviceMemory<float> & first_weights,const DeviceMemory<float> & second_weights,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)794 Stream &Stream::ThenSeparableConvolve(
795     const dnn::BatchDescriptor &batch_descriptor,
796     const DeviceMemory<float> &input_data,
797     const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
798     const DeviceMemory<float> &first_weights,
799     const DeviceMemory<float> &second_weights,
800     const dnn::ConvolutionDescriptor &convolution_descriptor,
801     const dnn::BatchDescriptor &output_descriptor,
802     DeviceMemory<float> *output) {
803   VLOG_CALL(
804       PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor),
805       PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights),
806       PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output));
807 
808   if (ok()) {
809     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
810       CheckError(dnn->DoSeparableConvolve(
811           this, batch_descriptor, input_data, filter_descriptor,
812           depth_multiplier, first_weights, second_weights,
813           convolution_descriptor, output_descriptor, output));
814     } else {
815       SetErrorAndLogNoDnnSupport();
816     }
817   }
818   return *this;
819 }
820 
ThenConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<double> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<double> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<double> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)821 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
822     const dnn::FilterDescriptor &filter_descriptor,
823     const DeviceMemory<double> &filter_data,
824     const dnn::BatchDescriptor &output_descriptor,
825     DeviceMemory<double> backward_output_data,
826     const dnn::ConvolutionDescriptor &convolution_descriptor,
827     const dnn::BatchDescriptor &input_descriptor,
828     DeviceMemory<double> *backward_input_data,
829     ScratchAllocator *scratch_allocator,
830     const dnn::AlgorithmConfig &algorithm_config,
831     dnn::ProfileResult *output_profile_result) {
832   VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
833             PARAM(output_descriptor), PARAM(backward_output_data),
834             PARAM(convolution_descriptor), PARAM(input_descriptor),
835             PARAM(backward_input_data));
836 
837   if (ok()) {
838     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
839       DeviceMemory<uint8> scratch_memory;
840       dnn::AlgorithmDesc algorithm_desc;
841       auto status =
842           dnn->PrepareForConvolution(
843                  dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
844                  *backward_input_data, filter_descriptor, filter_data,
845                  output_descriptor, backward_output_data,
846                  convolution_descriptor, algorithm_config, scratch_allocator,
847                  &algorithm_desc, &scratch_memory)
848               .ok();
849       if (status) {
850         status = dnn->DoConvolveBackwardData(
851             this, filter_descriptor, filter_data, output_descriptor,
852             backward_output_data, convolution_descriptor, input_descriptor,
853             backward_input_data, algorithm_desc, &scratch_memory,
854             output_profile_result);
855       }
856       if (!status && !output_profile_result) {
857         SetError();
858       }
859     } else {
860       SetErrorAndLogNoDnnSupport();
861     }
862   }
863   return *this;
864 }
865 
ThenConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<float> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)866 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
867     const dnn::FilterDescriptor &filter_descriptor,
868     const DeviceMemory<float> &filter_data,
869     const dnn::BatchDescriptor &output_descriptor,
870     DeviceMemory<float> backward_output_data,
871     const dnn::ConvolutionDescriptor &convolution_descriptor,
872     const dnn::BatchDescriptor &input_descriptor,
873     DeviceMemory<float> *backward_input_data,
874     ScratchAllocator *scratch_allocator,
875     const dnn::AlgorithmConfig &algorithm_config,
876     dnn::ProfileResult *output_profile_result) {
877   VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
878             PARAM(output_descriptor), PARAM(backward_output_data),
879             PARAM(convolution_descriptor), PARAM(input_descriptor),
880             PARAM(backward_input_data));
881 
882   if (ok()) {
883     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
884       DeviceMemory<uint8> scratch_memory;
885       dnn::AlgorithmDesc algorithm_desc;
886       auto status =
887           dnn->PrepareForConvolution(
888                  dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
889                  *backward_input_data, filter_descriptor, filter_data,
890                  output_descriptor, backward_output_data,
891                  convolution_descriptor, algorithm_config, scratch_allocator,
892                  &algorithm_desc, &scratch_memory)
893               .ok();
894       if (status) {
895         status = dnn->DoConvolveBackwardData(
896             this, filter_descriptor, filter_data, output_descriptor,
897             backward_output_data, convolution_descriptor, input_descriptor,
898             backward_input_data, algorithm_desc, &scratch_memory,
899             output_profile_result);
900       }
901       if (!status && !output_profile_result) {
902         SetError();
903       }
904     } else {
905       SetErrorAndLogNoDnnSupport();
906     }
907   }
908   return *this;
909 }
910 
ThenConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<Eigen::half> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)911 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
912     const dnn::FilterDescriptor &filter_descriptor,
913     const DeviceMemory<Eigen::half> &filter_data,
914     const dnn::BatchDescriptor &output_descriptor,
915     DeviceMemory<Eigen::half> backward_output_data,
916     const dnn::ConvolutionDescriptor &convolution_descriptor,
917     const dnn::BatchDescriptor &input_descriptor,
918     DeviceMemory<Eigen::half> *backward_input_data,
919     ScratchAllocator *scratch_allocator,
920     const dnn::AlgorithmConfig &algorithm_config,
921     dnn::ProfileResult *output_profile_result) {
922   VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
923             PARAM(output_descriptor), PARAM(backward_output_data),
924             PARAM(convolution_descriptor), PARAM(input_descriptor),
925             PARAM(backward_input_data));
926 
927   if (ok()) {
928     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
929       DeviceMemory<uint8> scratch_memory;
930       dnn::AlgorithmDesc algorithm_desc;
931       auto status =
932           dnn->PrepareForConvolution(
933                  dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
934                  *backward_input_data, filter_descriptor, filter_data,
935                  output_descriptor, backward_output_data,
936                  convolution_descriptor, algorithm_config, scratch_allocator,
937                  &algorithm_desc, &scratch_memory)
938               .ok();
939       if (status) {
940         status = dnn->DoConvolveBackwardData(
941             this, filter_descriptor, filter_data, output_descriptor,
942             backward_output_data, convolution_descriptor, input_descriptor,
943             backward_input_data, algorithm_desc, &scratch_memory,
944             output_profile_result);
945       }
946       if (!status && !output_profile_result) {
947         SetError();
948       }
949     } else {
950       SetErrorAndLogNoDnnSupport();
951     }
952   }
953   return *this;
954 }
955 
ThenConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<double> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<double> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)956 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
957     const dnn::BatchDescriptor &input_descriptor,
958     const DeviceMemory<double> &input_data,
959     const dnn::BatchDescriptor &output_descriptor,
960     DeviceMemory<double> backward_output_data,
961     const dnn::ConvolutionDescriptor &convolution_descriptor,
962     const dnn::FilterDescriptor &filter_descriptor,
963     DeviceMemory<double> *backward_filter_data,
964     ScratchAllocator *scratch_allocator,
965     const dnn::AlgorithmConfig &algorithm_config,
966     dnn::ProfileResult *output_profile_result) {
967   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
968             PARAM(output_descriptor), PARAM(backward_output_data),
969             PARAM(convolution_descriptor), PARAM(filter_descriptor),
970             PARAM(backward_filter_data));
971 
972   if (ok()) {
973     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
974       DeviceMemory<uint8> scratch_memory;
975       dnn::AlgorithmDesc algorithm_desc;
976       auto status =
977           dnn->PrepareForConvolution(
978                  dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
979                  input_data, filter_descriptor, *backward_filter_data,
980                  output_descriptor, backward_output_data,
981                  convolution_descriptor, algorithm_config, scratch_allocator,
982                  &algorithm_desc, &scratch_memory)
983               .ok();
984       if (status) {
985         status = dnn->DoConvolveBackwardFilter(
986             this, input_descriptor, input_data, output_descriptor,
987             backward_output_data, convolution_descriptor, filter_descriptor,
988             backward_filter_data, algorithm_desc, &scratch_memory,
989             output_profile_result);
990       }
991       if (!status && !output_profile_result) {
992         SetError();
993       }
994     } else {
995       SetErrorAndLogNoDnnSupport();
996     }
997   }
998   return *this;
999 }
1000 
ThenConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<float> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1001 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
1002     const dnn::BatchDescriptor &input_descriptor,
1003     const DeviceMemory<float> &input_data,
1004     const dnn::BatchDescriptor &output_descriptor,
1005     DeviceMemory<float> backward_output_data,
1006     const dnn::ConvolutionDescriptor &convolution_descriptor,
1007     const dnn::FilterDescriptor &filter_descriptor,
1008     DeviceMemory<float> *backward_filter_data,
1009     ScratchAllocator *scratch_allocator,
1010     const dnn::AlgorithmConfig &algorithm_config,
1011     dnn::ProfileResult *output_profile_result) {
1012   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
1013             PARAM(output_descriptor), PARAM(backward_output_data),
1014             PARAM(convolution_descriptor), PARAM(filter_descriptor),
1015             PARAM(backward_filter_data));
1016 
1017   if (ok()) {
1018     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1019       DeviceMemory<uint8> scratch_memory;
1020       dnn::AlgorithmDesc algorithm_desc;
1021       auto status =
1022           dnn->PrepareForConvolution(
1023                  dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
1024                  input_data, filter_descriptor, *backward_filter_data,
1025                  output_descriptor, backward_output_data,
1026                  convolution_descriptor, algorithm_config, scratch_allocator,
1027                  &algorithm_desc, &scratch_memory)
1028               .ok();
1029       if (status) {
1030         status = dnn->DoConvolveBackwardFilter(
1031             this, input_descriptor, input_data, output_descriptor,
1032             backward_output_data, convolution_descriptor, filter_descriptor,
1033             backward_filter_data, algorithm_desc, &scratch_memory,
1034             output_profile_result);
1035       }
1036       if (!status && !output_profile_result) {
1037         SetError();
1038       }
1039     } else {
1040       SetErrorAndLogNoDnnSupport();
1041     }
1042   }
1043   return *this;
1044 }
1045 
ThenConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<Eigen::half> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1046 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
1047     const dnn::BatchDescriptor &input_descriptor,
1048     const DeviceMemory<Eigen::half> &input_data,
1049     const dnn::BatchDescriptor &output_descriptor,
1050     DeviceMemory<Eigen::half> backward_output_data,
1051     const dnn::ConvolutionDescriptor &convolution_descriptor,
1052     const dnn::FilterDescriptor &filter_descriptor,
1053     DeviceMemory<Eigen::half> *backward_filter_data,
1054     ScratchAllocator *scratch_allocator,
1055     const dnn::AlgorithmConfig &algorithm_config,
1056     dnn::ProfileResult *output_profile_result) {
1057   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
1058             PARAM(output_descriptor), PARAM(backward_output_data),
1059             PARAM(convolution_descriptor), PARAM(filter_descriptor),
1060             PARAM(backward_filter_data));
1061 
1062   if (ok()) {
1063     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1064       DeviceMemory<uint8> scratch_memory;
1065       dnn::AlgorithmDesc algorithm_desc;
1066       auto status =
1067           dnn->PrepareForConvolution(
1068                  dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
1069                  input_data, filter_descriptor, *backward_filter_data,
1070                  output_descriptor, backward_output_data,
1071                  convolution_descriptor, algorithm_config, scratch_allocator,
1072                  &algorithm_desc, &scratch_memory)
1073               .ok();
1074       if (status) {
1075         status = dnn->DoConvolveBackwardFilter(
1076             this, input_descriptor, input_data, output_descriptor,
1077             backward_output_data, convolution_descriptor, filter_descriptor,
1078             backward_filter_data, algorithm_desc, &scratch_memory,
1079             output_profile_result);
1080       }
1081       if (!status && !output_profile_result) {
1082         SetError();
1083       }
1084     } else {
1085       SetErrorAndLogNoDnnSupport();
1086     }
1087   }
1088   return *this;
1089 }
1090 
1091 template <typename T>
ThenConvolveBackwardBiasImpl(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<T> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<T> * backward_bias_data)1092 Stream &Stream::ThenConvolveBackwardBiasImpl(
1093     const dnn::BatchDescriptor &input_descriptor,
1094     const DeviceMemory<T> &input_data,
1095     const dnn::BatchDescriptor &bias_descriptor,
1096     DeviceMemory<T> *backward_bias_data) {
1097   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(bias_descriptor),
1098             PARAM(backward_bias_data));
1099 
1100   if (ok()) {
1101     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1102       CheckError(dnn->DoConvolveBackwardBias(this, input_descriptor, input_data,
1103                                              bias_descriptor,
1104                                              backward_bias_data));
1105     } else {
1106       SetErrorAndLogNoDnnSupport();
1107     }
1108   }
1109   return *this;
1110 }
1111 
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<double> * backward_bias_data)1112 Stream &Stream::ThenConvolveBackwardBias(
1113     const dnn::BatchDescriptor &input_descriptor,
1114     const DeviceMemory<double> &input_data,
1115     const dnn::BatchDescriptor &bias_descriptor,
1116     DeviceMemory<double> *backward_bias_data) {
1117   return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
1118                                       bias_descriptor, backward_bias_data);
1119 }
1120 
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<float> * backward_bias_data)1121 Stream &Stream::ThenConvolveBackwardBias(
1122     const dnn::BatchDescriptor &input_descriptor,
1123     const DeviceMemory<float> &input_data,
1124     const dnn::BatchDescriptor &bias_descriptor,
1125     DeviceMemory<float> *backward_bias_data) {
1126   return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
1127                                       bias_descriptor, backward_bias_data);
1128 }
1129 
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<Eigen::half> * backward_bias_data)1130 Stream &Stream::ThenConvolveBackwardBias(
1131     const dnn::BatchDescriptor &input_descriptor,
1132     const DeviceMemory<Eigen::half> &input_data,
1133     const dnn::BatchDescriptor &bias_descriptor,
1134     DeviceMemory<Eigen::half> *backward_bias_data) {
1135   return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
1136                                       bias_descriptor, backward_bias_data);
1137 }
1138 
ThenMatMul(const DeviceMemory<float> & input_data,const DeviceMemory<float> & weights,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1139 Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
1140                            const DeviceMemory<float> &weights,
1141                            const dnn::BatchDescriptor &input_dimensions,
1142                            const dnn::BatchDescriptor &output_dimensions,
1143                            DeviceMemory<float> *output_data) {
1144   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions),
1145             PARAM(output_dimensions), PARAM(output_data));
1146 
1147   if (ok()) {
1148     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1149       CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions,
1150                                output_dimensions, output_data));
1151     } else {
1152       SetErrorAndLogNoDnnSupport();
1153     }
1154   }
1155   return *this;
1156 }
1157 
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int8> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1158 Stream &Stream::ThenMatMulQuantized(
1159     const DeviceMemory<float> &input_data, const DeviceMemory<int8> &weights,
1160     const DeviceMemory<float> &weight_scales,
1161     const dnn::BatchDescriptor &input_dimensions,
1162     const dnn::BatchDescriptor &output_dimensions,
1163     DeviceMemory<float> *output_data) {
1164   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
1165             PARAM(input_dimensions), PARAM(output_dimensions),
1166             PARAM(output_data));
1167 
1168   if (ok()) {
1169     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1170       CheckError(dnn->DoMatMulQuantized(this, input_data, weights,
1171                                         weight_scales, input_dimensions,
1172                                         output_dimensions, output_data));
1173     } else {
1174       SetErrorAndLogNoDnnSupport();
1175     }
1176   }
1177   return *this;
1178 }
1179 
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int16> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1180 Stream &Stream::ThenMatMulQuantized(
1181     const DeviceMemory<float> &input_data, const DeviceMemory<int16> &weights,
1182     const DeviceMemory<float> &weight_scales,
1183     const dnn::BatchDescriptor &input_dimensions,
1184     const dnn::BatchDescriptor &output_dimensions,
1185     DeviceMemory<float> *output_data) {
1186   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
1187             PARAM(input_dimensions), PARAM(output_dimensions),
1188             PARAM(output_data));
1189 
1190   if (ok()) {
1191     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1192       CheckError(dnn->DoMatMulQuantized(this, input_data, weights,
1193                                         weight_scales, input_dimensions,
1194                                         output_dimensions, output_data));
1195     } else {
1196       SetErrorAndLogNoDnnSupport();
1197     }
1198   }
1199   return *this;
1200 }
1201 
ThenBiasAdd(const DeviceMemory<float> & input_data,const DeviceMemory<float> & biases,const dnn::BatchDescriptor & dimensions,DeviceMemory<float> * output_data)1202 Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data,
1203                             const DeviceMemory<float> &biases,
1204                             const dnn::BatchDescriptor &dimensions,
1205                             DeviceMemory<float> *output_data) {
1206   VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions),
1207             PARAM(output_data));
1208 
1209   if (ok()) {
1210     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1211       CheckError(
1212           dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data));
1213     } else {
1214       SetErrorAndLogNoDnnSupport();
1215     }
1216   }
1217   return *this;
1218 }
1219 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<double> * output_data,ScratchAllocator * workspace_allocator)1220 Stream &Stream::ThenPoolForward(
1221     const dnn::PoolingDescriptor &pooling_dimensions,
1222     const dnn::BatchDescriptor &input_dimensions,
1223     const DeviceMemory<double> &input_data,
1224     const dnn::BatchDescriptor &output_dimensions,
1225     DeviceMemory<double> *output_data, ScratchAllocator *workspace_allocator) {
1226   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1227             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1228             PARAM(workspace_allocator));
1229 
1230   if (ok()) {
1231     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1232       CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1233                                     input_data, output_dimensions, output_data,
1234                                     workspace_allocator));
1235     } else {
1236       SetError();
1237       LOG(WARNING)
1238           << "attempting to perform DNN operation using StreamExecutor "
1239              "without DNN support";
1240     }
1241   }
1242   return *this;
1243 }
1244 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data,ScratchAllocator * workspace_allocator)1245 Stream &Stream::ThenPoolForward(
1246     const dnn::PoolingDescriptor &pooling_dimensions,
1247     const dnn::BatchDescriptor &input_dimensions,
1248     const DeviceMemory<float> &input_data,
1249     const dnn::BatchDescriptor &output_dimensions,
1250     DeviceMemory<float> *output_data, ScratchAllocator *workspace_allocator) {
1251   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1252             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1253             PARAM(workspace_allocator));
1254 
1255   if (ok()) {
1256     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1257       CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1258                                     input_data, output_dimensions, output_data,
1259                                     workspace_allocator));
1260     } else {
1261       SetErrorAndLogNoDnnSupport();
1262     }
1263   }
1264   return *this;
1265 }
1266 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<Eigen::half> * output_data,ScratchAllocator * workspace_allocator)1267 Stream &Stream::ThenPoolForward(
1268     const dnn::PoolingDescriptor &pooling_dimensions,
1269     const dnn::BatchDescriptor &input_dimensions,
1270     const DeviceMemory<Eigen::half> &input_data,
1271     const dnn::BatchDescriptor &output_dimensions,
1272     DeviceMemory<Eigen::half> *output_data,
1273     ScratchAllocator *workspace_allocator) {
1274   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1275             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1276             PARAM(workspace_allocator));
1277 
1278   if (ok()) {
1279     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1280       CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1281                                     input_data, output_dimensions, output_data,
1282                                     workspace_allocator));
1283     } else {
1284       SetErrorAndLogNoDnnSupport();
1285     }
1286   }
1287   return *this;
1288 }
1289 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<int8> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<int8> * output_data,ScratchAllocator * workspace_allocator)1290 Stream &Stream::ThenPoolForward(
1291     const dnn::PoolingDescriptor &pooling_dimensions,
1292     const dnn::BatchDescriptor &input_dimensions,
1293     const DeviceMemory<int8> &input_data,
1294     const dnn::BatchDescriptor &output_dimensions,
1295     DeviceMemory<int8> *output_data, ScratchAllocator *workspace_allocator) {
1296   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1297             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1298             PARAM(workspace_allocator));
1299 
1300   if (ok()) {
1301     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1302       CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1303                                     input_data, output_dimensions, output_data,
1304                                     workspace_allocator));
1305     } else {
1306       SetErrorAndLogNoDnnSupport();
1307     }
1308   }
1309   return *this;
1310 }
1311 
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<double> & output_data,const DeviceMemory<double> & input_diff_data,DeviceMemory<double> * output_diff_data,ScratchAllocator * workspace_allocator)1312 Stream &Stream::ThenPoolBackward(
1313     const dnn::PoolingDescriptor &pooling_dimensions,
1314     const dnn::BatchDescriptor &input_dimensions,
1315     const DeviceMemory<double> &input_data,
1316     const dnn::BatchDescriptor &output_dimensions,
1317     const DeviceMemory<double> &output_data,
1318     const DeviceMemory<double> &input_diff_data,
1319     DeviceMemory<double> *output_diff_data,
1320     ScratchAllocator *workspace_allocator) {
1321   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1322             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1323             PARAM(input_diff_data), PARAM(output_diff_data),
1324             PARAM(workspace_allocator));
1325 
1326   if (ok()) {
1327     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1328       CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
1329                                      input_data, output_dimensions, output_data,
1330                                      input_diff_data, output_diff_data,
1331                                      workspace_allocator));
1332     } else {
1333       SetError();
1334       LOG(WARNING)
1335           << "attempting to perform DNN operation using StreamExecutor "
1336              "without DNN support";
1337     }
1338   }
1339   return *this;
1340 }
1341 
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<float> & output_data,const DeviceMemory<float> & input_diff_data,DeviceMemory<float> * output_diff_data,ScratchAllocator * workspace_allocator)1342 Stream &Stream::ThenPoolBackward(
1343     const dnn::PoolingDescriptor &pooling_dimensions,
1344     const dnn::BatchDescriptor &input_dimensions,
1345     const DeviceMemory<float> &input_data,
1346     const dnn::BatchDescriptor &output_dimensions,
1347     const DeviceMemory<float> &output_data,
1348     const DeviceMemory<float> &input_diff_data,
1349     DeviceMemory<float> *output_diff_data,
1350     ScratchAllocator *workspace_allocator) {
1351   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1352             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1353             PARAM(input_diff_data), PARAM(output_diff_data),
1354             PARAM(workspace_allocator));
1355 
1356   if (ok()) {
1357     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1358       CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
1359                                      input_data, output_dimensions, output_data,
1360                                      input_diff_data, output_diff_data,
1361                                      workspace_allocator));
1362     } else {
1363       SetErrorAndLogNoDnnSupport();
1364     }
1365   }
1366   return *this;
1367 }
1368 
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<Eigen::half> & output_data,const DeviceMemory<Eigen::half> & input_diff_data,DeviceMemory<Eigen::half> * output_diff_data,ScratchAllocator * workspace_allocator)1369 Stream &Stream::ThenPoolBackward(
1370     const dnn::PoolingDescriptor &pooling_dimensions,
1371     const dnn::BatchDescriptor &input_dimensions,
1372     const DeviceMemory<Eigen::half> &input_data,
1373     const dnn::BatchDescriptor &output_dimensions,
1374     const DeviceMemory<Eigen::half> &output_data,
1375     const DeviceMemory<Eigen::half> &input_diff_data,
1376     DeviceMemory<Eigen::half> *output_diff_data,
1377     ScratchAllocator *workspace_allocator) {
1378   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1379             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1380             PARAM(input_diff_data), PARAM(output_diff_data),
1381             PARAM(workspace_allocator));
1382 
1383   if (ok()) {
1384     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1385       CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
1386                                      input_data, output_dimensions, output_data,
1387                                      input_diff_data, output_diff_data,
1388                                      workspace_allocator));
1389     } else {
1390       SetErrorAndLogNoDnnSupport();
1391     }
1392   }
1393   return *this;
1394 }
1395 
ThenNormalizeWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)1396 Stream &Stream::ThenNormalizeWithDimensions(
1397     const dnn::NormalizeDescriptor &normalize_descriptor,
1398     const dnn::BatchDescriptor &dimensions,
1399     const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
1400   VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data),
1401             PARAM(output_data));
1402 
1403   if (ok()) {
1404     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1405       CheckError(dnn->DoNormalizeWithDimensions(
1406           this, normalize_descriptor, dimensions, input_data, output_data));
1407     } else {
1408       SetErrorAndLogNoDnnSupport();
1409     }
1410   }
1411   return *this;
1412 }
1413 
ThenNormalizeBackwardWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & raw_data,const DeviceMemory<float> & normalized_data,const DeviceMemory<float> & normalized_variable_gradient,DeviceMemory<float> * raw_variable_gradient,ScratchAllocator * workspace_allocator)1414 Stream &Stream::ThenNormalizeBackwardWithDimensions(
1415     const dnn::NormalizeDescriptor &normalize_descriptor,
1416     const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data,
1417     const DeviceMemory<float> &normalized_data,
1418     const DeviceMemory<float> &normalized_variable_gradient,
1419     DeviceMemory<float> *raw_variable_gradient,
1420     ScratchAllocator *workspace_allocator) {
1421   VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data),
1422             PARAM(normalized_data), PARAM(normalized_variable_gradient),
1423             PARAM(raw_variable_gradient), PARAM(workspace_allocator));
1424 
1425   if (ok()) {
1426     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1427       CheckError(dnn->DoNormalizeBackwardWithDimensions(
1428           this, normalize_descriptor, dimensions, raw_data, normalized_data,
1429           normalized_variable_gradient, raw_variable_gradient,
1430           workspace_allocator));
1431     } else {
1432       SetErrorAndLogNoDnnSupport();
1433     }
1434   }
1435   return *this;
1436 }
1437 
ThenActivate(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)1438 Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode,
1439                              const dnn::BatchDescriptor &dimensions,
1440                              const DeviceMemory<float> &input_data,
1441                              DeviceMemory<float> *output_data) {
1442   return ThenActivateWithOptions(activation_mode, dimensions, input_data,
1443                                  output_data, /*options=*/0);
1444 }
1445 
ThenActivateWithOptions(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data,uint64 options)1446 Stream &Stream::ThenActivateWithOptions(dnn::ActivationMode activation_mode,
1447                                         const dnn::BatchDescriptor &dimensions,
1448                                         const DeviceMemory<float> &input_data,
1449                                         DeviceMemory<float> *output_data,
1450                                         uint64 options) {
1451   VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data),
1452             PARAM(output_data), PARAM(options));
1453 
1454   if (ok()) {
1455     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1456       CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data,
1457                                  output_data, options));
1458     } else {
1459       SetErrorAndLogNoDnnSupport();
1460     }
1461   }
1462   return *this;
1463 }
1464 
ThenDepthConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data)1465 Stream &Stream::ThenDepthConcatenate(
1466     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1467     port::ArraySlice<const DeviceMemory<float> *> input_data,
1468     DeviceMemory<float> *output_data) {
1469   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
1470 
1471   for (size_t i = 1; i < input_dimensions.size(); ++i) {
1472     if (input_dimensions[i].count() != input_dimensions[0].count() ||
1473         input_dimensions[i].height() != input_dimensions[0].height() ||
1474         input_dimensions[i].width() != input_dimensions[0].width()) {
1475       SetError();
1476       LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n"
1477                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
1478                  << "input_dimensions[" << i
1479                  << "]: " << input_dimensions[i].ToString();
1480       return *this;
1481     }
1482   }
1483 
1484   if (ok()) {
1485     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1486       CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data,
1487                                          output_data));
1488     } else {
1489       SetErrorAndLogNoDnnSupport();
1490     }
1491   }
1492   return *this;
1493 }
1494 
ThenSpaceConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data,dnn::SpaceConcatenateMode concat_direction)1495 Stream &Stream::ThenSpaceConcatenate(
1496     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1497     port::ArraySlice<const DeviceMemory<float> *> input_data,
1498     DeviceMemory<float> *output_data,
1499     dnn::SpaceConcatenateMode concat_direction) {
1500   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
1501 
1502   // Check that the input dimensions of all the other batches match those of the
1503   // first batch.
1504   for (size_t i = 1; i < input_dimensions.size(); ++i) {
1505     if ((concat_direction == dnn::SpaceConcatenateMode::XDirection) &&
1506         (input_dimensions[i].count() != input_dimensions[0].count() ||
1507          input_dimensions[i].height() != input_dimensions[0].height() ||
1508          input_dimensions[i].feature_map_count() !=
1509              input_dimensions[0].feature_map_count())) {
1510       SetError();
1511       LOG(ERROR) << "Incompatible dimensions for X concatenation.\n"
1512                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
1513                  << "input_dimensions[" << i
1514                  << "]: " << input_dimensions[i].ToString();
1515       return *this;
1516     }
1517 
1518     if ((concat_direction == dnn::SpaceConcatenateMode::YDirection) &&
1519         (input_dimensions[i].count() != input_dimensions[0].count() ||
1520          input_dimensions[i].width() != input_dimensions[0].width() ||
1521          input_dimensions[i].feature_map_count() !=
1522              input_dimensions[0].feature_map_count())) {
1523       SetError();
1524       LOG(ERROR) << "Incompatible dimensions for Y concatenation.\n"
1525                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
1526                  << "input_dimensions[" << i
1527                  << "]: " << input_dimensions[i].ToString();
1528       return *this;
1529     }
1530   }
1531   if (ok()) {
1532     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1533       CheckError(dnn->DoSpaceConcatenate(this, input_dimensions, input_data,
1534                                          output_data, concat_direction));
1535     } else {
1536       SetErrorAndLogNoDnnSupport();
1537     }
1538   }
1539   return *this;
1540 }
1541 
ThenReshape(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1542 Stream &Stream::ThenReshape(const dnn::BatchDescriptor &input_dimensions,
1543                             const DeviceMemory<float> &input_data,
1544                             const dnn::BatchDescriptor &output_dimensions,
1545                             DeviceMemory<float> *output_data) {
1546   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1547             PARAM(output_dimensions), PARAM(output_data));
1548 
1549   if (ok()) {
1550     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1551       CheckError(dnn->DoReshape(this, input_dimensions, input_data,
1552                                 output_dimensions, output_data));
1553     } else {
1554       SetErrorAndLogNoDnnSupport();
1555     }
1556   }
1557   return *this;
1558 }
1559 
ThenDepthToSpace(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & depth_to_space_layout,const int sqrt_depth_reduction,DeviceMemory<float> * output_data)1560 Stream &Stream::ThenDepthToSpace(
1561     const dnn::BatchDescriptor &input_dimensions,
1562     const DeviceMemory<float> &input_data,
1563     const dnn::DepthToSpaceLayout &depth_to_space_layout,
1564     const int sqrt_depth_reduction, DeviceMemory<float> *output_data) {
1565   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1566             PARAM(depth_to_space_layout), PARAM(sqrt_depth_reduction),
1567             PARAM(output_data));
1568 
1569   if (ok()) {
1570     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1571       CheckError(dnn->DoDepthToSpace(this, input_dimensions, input_data,
1572                                      depth_to_space_layout,
1573                                      sqrt_depth_reduction, output_data));
1574     } else {
1575       SetErrorAndLogNoDnnSupport();
1576     }
1577   }
1578   return *this;
1579 }
1580 
ThenSpaceToDepth(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & space_to_depth_layout,const int sqrt_depth_increase,DeviceMemory<float> * output_data)1581 Stream &Stream::ThenSpaceToDepth(
1582     const dnn::BatchDescriptor &input_dimensions,
1583     const DeviceMemory<float> &input_data,
1584     const dnn::DepthToSpaceLayout &space_to_depth_layout,
1585     const int sqrt_depth_increase, DeviceMemory<float> *output_data) {
1586   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1587             PARAM(space_to_depth_layout), PARAM(sqrt_depth_increase),
1588             PARAM(output_data));
1589 
1590   if (ok()) {
1591     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1592       CheckError(dnn->DoSpaceToDepth(this, input_dimensions, input_data,
1593                                      space_to_depth_layout, sqrt_depth_increase,
1594                                      output_data));
1595     } else {
1596       SetErrorAndLogNoDnnSupport();
1597     }
1598   }
1599   return *this;
1600 }
1601 
ThenElementwiseOperate(dnn::ElementwiseOperation operation,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1602 Stream &Stream::ThenElementwiseOperate(
1603     dnn::ElementwiseOperation operation,
1604     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1605     port::ArraySlice<const DeviceMemory<float> *> input_data,
1606     const dnn::BatchDescriptor &output_dimensions,
1607     DeviceMemory<float> *output_data) {
1608   VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data),
1609             PARAM(output_dimensions), PARAM(output_data));
1610 
1611   if (ok()) {
1612     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1613       CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions,
1614                                            input_data, output_dimensions,
1615                                            output_data));
1616     } else {
1617       SetErrorAndLogNoDnnSupport();
1618     }
1619   }
1620   return *this;
1621 }
1622 
ThenElementwiseOperateScaledQuantized(dnn::ElementwiseOperation operation,port::ArraySlice<int> input_multiplicands,int output_divisor,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1623 Stream &Stream::ThenElementwiseOperateScaledQuantized(
1624     dnn::ElementwiseOperation operation,
1625     port::ArraySlice<int> input_multiplicands, int output_divisor,
1626     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1627     port::ArraySlice<const DeviceMemory<float> *> input_data,
1628     const dnn::BatchDescriptor &output_dimensions,
1629     DeviceMemory<float> *output_data) {
1630   VLOG_CALL(PARAM(operation), PARAM(input_multiplicands), PARAM(output_divisor),
1631             PARAM(input_dimensions), PARAM(input_data),
1632             PARAM(output_dimensions), PARAM(output_data));
1633 
1634   if (ok()) {
1635     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1636       CheckError(dnn->DoElementwiseOperateScaledQuantized(
1637           this, operation, input_multiplicands, output_divisor,
1638           input_dimensions, input_data, output_dimensions, output_data));
1639     } else {
1640       SetErrorAndLogNoDnnSupport();
1641     }
1642   }
1643   return *this;
1644 }
1645 
ThenXYPad(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 left_pad,int64 right_pad,int64 top_pad,int64 bottom_pad,DeviceMemory<float> * output_data)1646 Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions,
1647                           const DeviceMemory<float> &input_data, int64 left_pad,
1648                           int64 right_pad, int64 top_pad, int64 bottom_pad,
1649                           DeviceMemory<float> *output_data) {
1650   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad),
1651             PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad),
1652             PARAM(output_data));
1653 
1654   if (ok()) {
1655     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1656       CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad,
1657                               top_pad, bottom_pad, output_data));
1658     } else {
1659       SetErrorAndLogNoDnnSupport();
1660     }
1661   }
1662   return *this;
1663 }
1664 
ThenXYSlice(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 left_trim,int64 right_trim,int64 top_trim,int64 bottom_trim,DeviceMemory<float> * output_data)1665 Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions,
1666                             const DeviceMemory<float> &input_data,
1667                             int64 left_trim, int64 right_trim, int64 top_trim,
1668                             int64 bottom_trim,
1669                             DeviceMemory<float> *output_data) {
1670   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim),
1671             PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim),
1672             PARAM(output_data));
1673 
1674   if (ok()) {
1675     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1676       CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim,
1677                                 right_trim, top_trim, bottom_trim,
1678                                 output_data));
1679     } else {
1680       SetErrorAndLogNoDnnSupport();
1681     }
1682   }
1683   return *this;
1684 }
1685 
ThenXYBroadcast(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 replicate_x,int64 replicate_y,DeviceMemory<float> * output_data)1686 Stream &Stream::ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
1687                                 const DeviceMemory<float> &input_data,
1688                                 int64 replicate_x, int64 replicate_y,
1689                                 DeviceMemory<float> *output_data) {
1690   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(replicate_x),
1691             PARAM(replicate_y), PARAM(output_data));
1692 
1693   if (ok()) {
1694     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1695       CheckError(dnn->DoXYBroadcast(this, dimensions, input_data, replicate_x,
1696                                     replicate_y, output_data));
1697     } else {
1698       SetErrorAndLogNoDnnSupport();
1699     }
1700   }
1701   return *this;
1702 }
1703 
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,dnn::QuantizedActivationMode mode,void * host_dst,uint64 size)1704 Stream &Stream::ThenMemcpyD2HQuantized(
1705     const DeviceMemory<float> &gpu_unquantized_src,
1706     dnn::QuantizedActivationMode mode, void *host_dst, uint64 size) {
1707   VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst),
1708             PARAM(size));
1709 
1710   if (ok()) {
1711     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1712       CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode,
1713                                            host_dst, size));
1714     } else {
1715       SetErrorAndLogNoDnnSupport();
1716     }
1717   }
1718   return *this;
1719 }
1720 
ThenMemcpyH2DQuantized(const void * host_src,uint64 size,dnn::QuantizedActivationMode mode,DeviceMemory<float> * gpu_unquantized_dst)1721 Stream &Stream::ThenMemcpyH2DQuantized(
1722     const void *host_src, uint64 size, dnn::QuantizedActivationMode mode,
1723     DeviceMemory<float> *gpu_unquantized_dst) {
1724   VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode),
1725             PARAM(gpu_unquantized_dst));
1726 
1727   if (ok()) {
1728     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1729       CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode,
1730                                            gpu_unquantized_dst));
1731     } else {
1732       SetErrorAndLogNoDnnSupport();
1733     }
1734   }
1735   return *this;
1736 }
1737 
GetOrCreateSubStream()1738 Stream *Stream::GetOrCreateSubStream() {
1739   mutex_lock lock(mu_);
1740 
1741   // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
1742   // we encounter along the way.
1743   for (int64 index = 0; index < sub_streams_.size();) {
1744     std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
1745     if (pair.second) {
1746       // The sub_stream is reusable.
1747       Stream *sub_stream = pair.first.get();
1748       if (sub_stream->ok()) {
1749         VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
1750                 << sub_stream->DebugStreamPointers();
1751         pair.second = false;
1752         return sub_stream;
1753       }
1754 
1755       // The stream is reusable and not ok. Streams have a monotonic state
1756       // machine; the stream will remain in !ok forever. Swap it with the last
1757       // stream and pop it off.
1758       const int64 last = sub_streams_.size() - 1;
1759       if (index != last) {
1760         std::swap(pair, sub_streams_[last]);
1761       }
1762       sub_streams_.pop_back();
1763       VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
1764               << sub_stream->DebugStreamPointers();
1765     } else {
1766       // The sub_stream is not reusable, move on to the next one.
1767       ++index;
1768     }
1769   }
1770 
1771   // No streams are reusable; create a new stream.
1772   sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
1773                             false);
1774   Stream *sub_stream = sub_streams_.back().first.get();
1775   sub_stream->Init();
1776   if (!sub_stream->ok_) {
1777     LOG(ERROR) << "sub-stream failed to be initialized";
1778   }
1779   VLOG(1) << DebugStreamPointers() << " created new sub_stream "
1780           << sub_stream->DebugStreamPointers();
1781 
1782   return sub_stream;
1783 }
1784 
ReturnSubStream(Stream * sub_stream)1785 void Stream::ReturnSubStream(Stream *sub_stream) {
1786   mutex_lock lock(mu_);
1787 
1788   // Look for the sub-stream.
1789   for (int64 index = 0; index < sub_streams_.size(); ++index) {
1790     std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
1791     if (pair.first.get() != sub_stream) {
1792       continue;
1793     }
1794 
1795     // Found the sub_stream.
1796     if (sub_stream->ok()) {
1797       VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
1798               << sub_stream->DebugStreamPointers();
1799       pair.second = true;
1800     } else {
1801       // The returned stream is not ok. Streams have a monotonic state
1802       // machine; the stream will remain in !ok forever. Swap it with the last
1803       // stream and pop it off.
1804       VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
1805               << sub_stream->DebugStreamPointers();
1806       const int64 last = sub_streams_.size() - 1;
1807       if (index != last) {
1808         std::swap(pair, sub_streams_[last]);
1809       }
1810       sub_streams_.pop_back();
1811     }
1812     return;
1813   }
1814 
1815   LOG(FATAL) << DebugStreamPointers()
1816              << " did not create the returned sub-stream "
1817              << sub_stream->DebugStreamPointers();
1818 }
1819 
ThenStartTimer(Timer * t)1820 Stream &Stream::ThenStartTimer(Timer *t) {
1821   VLOG_CALL(PARAM(t));
1822 
1823   if (ok()) {
1824     CheckError(parent_->StartTimer(this, t));
1825   } else {
1826     LOG(INFO) << DebugStreamPointers()
1827               << " did not enqueue 'start timer': " << t;
1828   }
1829   return *this;
1830 }
1831 
ThenStopTimer(Timer * t)1832 Stream &Stream::ThenStopTimer(Timer *t) {
1833   VLOG_CALL(PARAM(t));
1834 
1835   if (ok()) {
1836     CheckError(parent_->StopTimer(this, t));
1837   } else {
1838     LOG(INFO) << DebugStreamPointers()
1839               << " did not enqueue 'stop timer': " << t;
1840   }
1841   return *this;
1842 }
1843 
ThenWaitFor(Stream * other)1844 Stream &Stream::ThenWaitFor(Stream *other) {
1845   VLOG_CALL(PARAM(other));
1846 
1847   CHECK(this != other) << "stream cannot wait for itself";
1848   if (ok() && other->ok()) {
1849     CheckError(parent_->CreateStreamDependency(this, other));
1850   } else {
1851     SetError();
1852     LOG(INFO) << DebugStreamPointers() << " did not wait for "
1853               << other->DebugStreamPointers();
1854   }
1855   return *this;
1856 }
1857 
ThenWaitFor(Event * event)1858 Stream &Stream::ThenWaitFor(Event *event) {
1859   VLOG_CALL(PARAM(event));
1860 
1861   if (ok()) {
1862     port::Status status = parent_->WaitForEvent(this, event);
1863     if (!status.ok()) {
1864       LOG(ERROR) << "Error waiting for event in stream: "
1865                  << status.error_message()
1866                  << "; not marking stream as bad, as the Event object may be "
1867                  << "at fault. Monitor for further errors.";
1868     }
1869   } else {
1870     LOG(INFO) << DebugStreamPointers() << " did not wait for an event.";
1871   }
1872   return *this;
1873 }
1874 
1875 // A functor that implements ThenBlasXXX interfaces, which calls DoBlasXXX
1876 // functions and logs for errors.
1877 template <typename... Args>
1878 struct ThenBlasImpl {
1879   // blas_func is the DoBlasXXX member function pointer, and args are its
1880   // arguments except the first one of Stream* type.
operator ()stream_executor::ThenBlasImpl1881   Stream &operator()(Stream *stream,
1882                      bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1883                      Args... args) {
1884     return Run(stream, blas_func, /*record_error=*/true, args...);
1885   }
1886 
1887   // Like operator(), but only calls stream->CheckError() if record_error is
1888   // true.
1889   Stream &Run(Stream *stream,
1890               bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1891               bool record_error, Args... args);
1892 };
1893 
1894 template <typename... Args>
Run(Stream * stream,bool (blas::BlasSupport::* blas_func)(Stream *,Args...),bool record_error,Args...args)1895 Stream &ThenBlasImpl<Args...>::Run(
1896     Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1897     bool record_error, Args... args) {
1898   if (stream->ok()) {
1899     bool ok;
1900     if (blas::BlasSupport *blas = stream->parent_->AsBlas()) {
1901       ok = (blas->*blas_func)(stream, args...);
1902     } else {
1903       LOG(WARNING)
1904           << "attempting to perform BLAS operation using StreamExecutor "
1905              "without BLAS support";
1906       ok = false;
1907     }
1908     if (record_error) {
1909       stream->CheckError(ok);
1910     }
1911   }
1912   return *stream;
1913 }
1914 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)1915 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
1916                              int incx, DeviceMemory<float> *result) {
1917   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1918 
1919   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
1920       impl;
1921   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1922               result);
1923 }
1924 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)1925 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
1926                              int incx, DeviceMemory<double> *result) {
1927   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1928 
1929   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1930                DeviceMemory<double> *> impl;
1931   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1932               result);
1933 }
1934 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)1935 Stream &Stream::ThenBlasAsum(uint64 elem_count,
1936                              const DeviceMemory<std::complex<float>> &x,
1937                              int incx, DeviceMemory<float> *result) {
1938   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1939 
1940   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1941                DeviceMemory<float> *> impl;
1942   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1943               result);
1944 }
1945 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)1946 Stream &Stream::ThenBlasAsum(uint64 elem_count,
1947                              const DeviceMemory<std::complex<double>> &x,
1948                              int incx, DeviceMemory<double> *result) {
1949   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1950 
1951   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1952                DeviceMemory<double> *> impl;
1953   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1954               result);
1955 }
1956 
ThenBlasAxpy(uint64 elem_count,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)1957 Stream &Stream::ThenBlasAxpy(uint64 elem_count, float alpha,
1958                              const DeviceMemory<float> &x, int incx,
1959                              DeviceMemory<float> *y, int incy) {
1960   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1961             PARAM(incy));
1962 
1963   ThenBlasImpl<uint64, float, const DeviceMemory<float> &, int,
1964                DeviceMemory<float> *, int> impl;
1965   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1966               y, incy);
1967 }
1968 
ThenBlasAxpy(uint64 elem_count,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)1969 Stream &Stream::ThenBlasAxpy(uint64 elem_count, double alpha,
1970                              const DeviceMemory<double> &x, int incx,
1971                              DeviceMemory<double> *y, int incy) {
1972   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1973             PARAM(incy));
1974 
1975   ThenBlasImpl<uint64, double, const DeviceMemory<double> &, int,
1976                DeviceMemory<double> *, int> impl;
1977   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1978               y, incy);
1979 }
1980 
ThenBlasAxpy(uint64 elem_count,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1981 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
1982                              const DeviceMemory<std::complex<float>> &x,
1983                              int incx, DeviceMemory<std::complex<float>> *y,
1984                              int incy) {
1985   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1986             PARAM(incy));
1987 
1988   ThenBlasImpl<uint64, std::complex<float>,
1989                const DeviceMemory<std::complex<float>> &, int,
1990                DeviceMemory<std::complex<float>> *, int> impl;
1991   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1992               y, incy);
1993 }
1994 
ThenBlasAxpy(uint64 elem_count,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1995 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
1996                              const DeviceMemory<std::complex<double>> &x,
1997                              int incx, DeviceMemory<std::complex<double>> *y,
1998                              int incy) {
1999   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2000             PARAM(incy));
2001 
2002   ThenBlasImpl<uint64, std::complex<double>,
2003                const DeviceMemory<std::complex<double>> &, int,
2004                DeviceMemory<std::complex<double>> *, int> impl;
2005   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
2006               y, incy);
2007 }
2008 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)2009 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
2010                              int incx, DeviceMemory<float> *y, int incy) {
2011   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2012 
2013   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
2014                int> impl;
2015   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2016               incy);
2017 }
2018 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)2019 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
2020                              int incx, DeviceMemory<double> *y, int incy) {
2021   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2022 
2023   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2024                DeviceMemory<double> *, int> impl;
2025   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2026               incy);
2027 }
2028 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)2029 Stream &Stream::ThenBlasCopy(uint64 elem_count,
2030                              const DeviceMemory<std::complex<float>> &x,
2031                              int incx, DeviceMemory<std::complex<float>> *y,
2032                              int incy) {
2033   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2034 
2035   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2036                DeviceMemory<std::complex<float>> *, int> impl;
2037   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2038               incy);
2039 }
2040 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)2041 Stream &Stream::ThenBlasCopy(uint64 elem_count,
2042                              const DeviceMemory<std::complex<double>> &x,
2043                              int incx, DeviceMemory<std::complex<double>> *y,
2044                              int incy) {
2045   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2046 
2047   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2048                DeviceMemory<std::complex<double>> *, int> impl;
2049   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2050               incy);
2051 }
2052 
ThenBlasDot(uint64 elem_count,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * result)2053 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x,
2054                             int incx, const DeviceMemory<float> &y, int incy,
2055                             DeviceMemory<float> *result) {
2056   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2057             PARAM(result));
2058 
2059   ThenBlasImpl<uint64, const DeviceMemory<float> &, int,
2060                const DeviceMemory<float> &, int, DeviceMemory<float> *> impl;
2061   return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
2062               result);
2063 }
2064 
ThenBlasDot(uint64 elem_count,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * result)2065 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
2066                             int incx, const DeviceMemory<double> &y, int incy,
2067                             DeviceMemory<double> *result) {
2068   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2069             PARAM(result));
2070 
2071   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2072                const DeviceMemory<double> &, int, DeviceMemory<double> *> impl;
2073   return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
2074               result);
2075 }
2076 
ThenBlasDotc(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)2077 Stream &Stream::ThenBlasDotc(uint64 elem_count,
2078                              const DeviceMemory<std::complex<float>> &x,
2079                              int incx,
2080                              const DeviceMemory<std::complex<float>> &y,
2081                              int incy,
2082                              DeviceMemory<std::complex<float>> *result) {
2083   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2084             PARAM(result));
2085 
2086   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2087                const DeviceMemory<std::complex<float>> &, int,
2088                DeviceMemory<std::complex<float>> *> impl;
2089   return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
2090               incy, result);
2091 }
2092 
ThenBlasDotc(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)2093 Stream &Stream::ThenBlasDotc(uint64 elem_count,
2094                              const DeviceMemory<std::complex<double>> &x,
2095                              int incx,
2096                              const DeviceMemory<std::complex<double>> &y,
2097                              int incy,
2098                              DeviceMemory<std::complex<double>> *result) {
2099   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2100             PARAM(result));
2101 
2102   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2103                const DeviceMemory<std::complex<double>> &, int,
2104                DeviceMemory<std::complex<double>> *> impl;
2105   return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
2106               incy, result);
2107 }
2108 
ThenBlasDotu(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)2109 Stream &Stream::ThenBlasDotu(uint64 elem_count,
2110                              const DeviceMemory<std::complex<float>> &x,
2111                              int incx,
2112                              const DeviceMemory<std::complex<float>> &y,
2113                              int incy,
2114                              DeviceMemory<std::complex<float>> *result) {
2115   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2116             PARAM(result));
2117 
2118   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2119                const DeviceMemory<std::complex<float>> &, int,
2120                DeviceMemory<std::complex<float>> *> impl;
2121   return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
2122               incy, result);
2123 }
2124 
ThenBlasDotu(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)2125 Stream &Stream::ThenBlasDotu(uint64 elem_count,
2126                              const DeviceMemory<std::complex<double>> &x,
2127                              int incx,
2128                              const DeviceMemory<std::complex<double>> &y,
2129                              int incy,
2130                              DeviceMemory<std::complex<double>> *result) {
2131   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2132             PARAM(result));
2133 
2134   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2135                const DeviceMemory<std::complex<double>> &, int,
2136                DeviceMemory<std::complex<double>> *> impl;
2137   return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
2138               incy, result);
2139 }
2140 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)2141 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
2142                              int incx, DeviceMemory<float> *result) {
2143   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2144 
2145   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
2146       impl;
2147   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2148               result);
2149 }
2150 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)2151 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
2152                              int incx, DeviceMemory<double> *result) {
2153   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2154 
2155   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2156                DeviceMemory<double> *> impl;
2157   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2158               result);
2159 }
2160 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)2161 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
2162                              const DeviceMemory<std::complex<float>> &x,
2163                              int incx, DeviceMemory<float> *result) {
2164   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2165 
2166   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2167                DeviceMemory<float> *> impl;
2168   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2169               result);
2170 }
2171 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)2172 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
2173                              const DeviceMemory<std::complex<double>> &x,
2174                              int incx, DeviceMemory<double> *result) {
2175   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2176 
2177   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2178                DeviceMemory<double> *> impl;
2179   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2180               result);
2181 }
2182 
ThenBlasRot(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,float c,float s)2183 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
2184                             DeviceMemory<float> *y, int incy, float c,
2185                             float s) {
2186   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2187             PARAM(c), PARAM(s));
2188 
2189   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
2190                float, float> impl;
2191   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2192               c, s);
2193 }
2194 
ThenBlasRot(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,double c,double s)2195 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x,
2196                             int incx, DeviceMemory<double> *y, int incy,
2197                             double c, double s) {
2198   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2199             PARAM(c), PARAM(s));
2200 
2201   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
2202                double, double> impl;
2203   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2204               c, s);
2205 }
2206 
ThenBlasRot(uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy,float c,float s)2207 Stream &Stream::ThenBlasRot(uint64 elem_count,
2208                             DeviceMemory<std::complex<float>> *x, int incx,
2209                             DeviceMemory<std::complex<float>> *y, int incy,
2210                             float c, float s) {
2211   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2212             PARAM(c), PARAM(s));
2213 
2214   ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
2215                DeviceMemory<std::complex<float>> *, int, float, float> impl;
2216   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2217               c, s);
2218 }
2219 
ThenBlasRot(uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy,double c,double s)2220 Stream &Stream::ThenBlasRot(uint64 elem_count,
2221                             DeviceMemory<std::complex<double>> *x, int incx,
2222                             DeviceMemory<std::complex<double>> *y, int incy,
2223                             double c, double s) {
2224   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2225             PARAM(c), PARAM(s));
2226 
2227   ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
2228                DeviceMemory<std::complex<double>> *, int, double, double> impl;
2229   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2230               c, s);
2231 }
2232 
ThenBlasRotg(DeviceMemory<float> * a,DeviceMemory<float> * b,DeviceMemory<float> * c,DeviceMemory<float> * s)2233 Stream &Stream::ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
2234                              DeviceMemory<float> *c, DeviceMemory<float> *s) {
2235   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2236 
2237   ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
2238                DeviceMemory<float> *, DeviceMemory<float> *> impl;
2239   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2240 }
2241 
ThenBlasRotg(DeviceMemory<double> * a,DeviceMemory<double> * b,DeviceMemory<double> * c,DeviceMemory<double> * s)2242 Stream &Stream::ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
2243                              DeviceMemory<double> *c, DeviceMemory<double> *s) {
2244   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2245 
2246   ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
2247                DeviceMemory<double> *, DeviceMemory<double> *> impl;
2248   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2249 }
2250 
ThenBlasRotg(DeviceMemory<std::complex<float>> * a,DeviceMemory<std::complex<float>> * b,DeviceMemory<float> * c,DeviceMemory<std::complex<float>> * s)2251 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
2252                              DeviceMemory<std::complex<float>> *b,
2253                              DeviceMemory<float> *c,
2254                              DeviceMemory<std::complex<float>> *s) {
2255   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2256 
2257   ThenBlasImpl<DeviceMemory<std::complex<float>> *,
2258                DeviceMemory<std::complex<float>> *, DeviceMemory<float> *,
2259                DeviceMemory<std::complex<float>> *> impl;
2260   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2261 }
2262 
ThenBlasRotg(DeviceMemory<std::complex<double>> * a,DeviceMemory<std::complex<double>> * b,DeviceMemory<double> * c,DeviceMemory<std::complex<double>> * s)2263 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
2264                              DeviceMemory<std::complex<double>> *b,
2265                              DeviceMemory<double> *c,
2266                              DeviceMemory<std::complex<double>> *s) {
2267   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2268 
2269   ThenBlasImpl<DeviceMemory<std::complex<double>> *,
2270                DeviceMemory<std::complex<double>> *, DeviceMemory<double> *,
2271                DeviceMemory<std::complex<double>> *> impl;
2272   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2273 }
2274 
ThenBlasRotm(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,const DeviceMemory<float> & param)2275 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x,
2276                              int incx, DeviceMemory<float> *y, int incy,
2277                              const DeviceMemory<float> &param) {
2278   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2279             PARAM(param));
2280 
2281   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
2282                const DeviceMemory<float> &> impl;
2283   return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
2284               incy, param);
2285 }
2286 
ThenBlasRotm(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,const DeviceMemory<double> & param)2287 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x,
2288                              int incx, DeviceMemory<double> *y, int incy,
2289                              const DeviceMemory<double> &param) {
2290   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2291             PARAM(param));
2292 
2293   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
2294                const DeviceMemory<double> &> impl;
2295   return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
2296               incy, param);
2297 }
2298 
ThenBlasRotmg(DeviceMemory<float> * d1,DeviceMemory<float> * d2,DeviceMemory<float> * x1,const DeviceMemory<float> & y1,DeviceMemory<float> * param)2299 Stream &Stream::ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
2300                               DeviceMemory<float> *x1,
2301                               const DeviceMemory<float> &y1,
2302                               DeviceMemory<float> *param) {
2303   VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
2304 
2305   ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
2306                DeviceMemory<float> *, const DeviceMemory<float> &,
2307                DeviceMemory<float> *> impl;
2308   return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
2309 }
2310 
ThenBlasRotmg(DeviceMemory<double> * d1,DeviceMemory<double> * d2,DeviceMemory<double> * x1,const DeviceMemory<double> & y1,DeviceMemory<double> * param)2311 Stream &Stream::ThenBlasRotmg(DeviceMemory<double> *d1,
2312                               DeviceMemory<double> *d2,
2313                               DeviceMemory<double> *x1,
2314                               const DeviceMemory<double> &y1,
2315                               DeviceMemory<double> *param) {
2316   VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
2317 
2318   ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
2319                DeviceMemory<double> *, const DeviceMemory<double> &,
2320                DeviceMemory<double> *> impl;
2321   return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
2322 }
2323 
ThenBlasScal(uint64 elem_count,float alpha,DeviceMemory<float> * x,int incx)2324 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
2325                              DeviceMemory<float> *x, int incx) {
2326   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2327 
2328   ThenBlasImpl<uint64, float, DeviceMemory<float> *, int> impl;
2329   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2330 }
2331 
ThenBlasScal(uint64 elem_count,double alpha,DeviceMemory<double> * x,int incx)2332 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
2333                              DeviceMemory<double> *x, int incx) {
2334   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2335 
2336   ThenBlasImpl<uint64, double, DeviceMemory<double> *, int> impl;
2337   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2338 }
2339 
ThenBlasScal(uint64 elem_count,float alpha,DeviceMemory<std::complex<float>> * x,int incx)2340 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
2341                              DeviceMemory<std::complex<float>> *x, int incx) {
2342   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2343 
2344   ThenBlasImpl<uint64, float, DeviceMemory<std::complex<float>> *, int> impl;
2345   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2346 }
2347 
ThenBlasScal(uint64 elem_count,double alpha,DeviceMemory<std::complex<double>> * x,int incx)2348 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
2349                              DeviceMemory<std::complex<double>> *x, int incx) {
2350   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2351 
2352   ThenBlasImpl<uint64, double, DeviceMemory<std::complex<double>> *, int> impl;
2353   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2354 }
2355 
ThenBlasScal(uint64 elem_count,std::complex<float> alpha,DeviceMemory<std::complex<float>> * x,int incx)2356 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
2357                              DeviceMemory<std::complex<float>> *x, int incx) {
2358   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2359 
2360   ThenBlasImpl<uint64, std::complex<float>, DeviceMemory<std::complex<float>> *,
2361                int> impl;
2362   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2363 }
2364 
ThenBlasScal(uint64 elem_count,std::complex<double> alpha,DeviceMemory<std::complex<double>> * x,int incx)2365 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
2366                              DeviceMemory<std::complex<double>> *x, int incx) {
2367   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2368 
2369   ThenBlasImpl<uint64, std::complex<double>,
2370                DeviceMemory<std::complex<double>> *, int> impl;
2371   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2372 }
2373 
ThenBlasSwap(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy)2374 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x,
2375                              int incx, DeviceMemory<float> *y, int incy) {
2376   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2377 
2378   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int>
2379       impl;
2380   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2381               incy);
2382 }
2383 
ThenBlasSwap(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy)2384 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x,
2385                              int incx, DeviceMemory<double> *y, int incy) {
2386   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2387 
2388   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int>
2389       impl;
2390   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2391               incy);
2392 }
2393 
ThenBlasSwap(uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy)2394 Stream &Stream::ThenBlasSwap(uint64 elem_count,
2395                              DeviceMemory<std::complex<float>> *x, int incx,
2396                              DeviceMemory<std::complex<float>> *y, int incy) {
2397   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2398 
2399   ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
2400                DeviceMemory<std::complex<float>> *, int> impl;
2401   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2402               incy);
2403 }
2404 
ThenBlasSwap(uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy)2405 Stream &Stream::ThenBlasSwap(uint64 elem_count,
2406                              DeviceMemory<std::complex<double>> *x, int incx,
2407                              DeviceMemory<std::complex<double>> *y, int incy) {
2408   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2409 
2410   ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
2411                DeviceMemory<std::complex<double>> *, int> impl;
2412   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2413               incy);
2414 }
2415 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)2416 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
2417                               int incx, DeviceMemory<int> *result) {
2418   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2419 
2420   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
2421       impl;
2422   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2423               result);
2424 }
2425 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)2426 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
2427                               int incx, DeviceMemory<int> *result) {
2428   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2429 
2430   ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
2431       impl;
2432   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2433               result);
2434 }
2435 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)2436 Stream &Stream::ThenBlasIamax(uint64 elem_count,
2437                               const DeviceMemory<std::complex<float>> &x,
2438                               int incx, DeviceMemory<int> *result) {
2439   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2440 
2441   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2442                DeviceMemory<int> *> impl;
2443   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2444               result);
2445 }
2446 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)2447 Stream &Stream::ThenBlasIamax(uint64 elem_count,
2448                               const DeviceMemory<std::complex<double>> &x,
2449                               int incx, DeviceMemory<int> *result) {
2450   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2451 
2452   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2453                DeviceMemory<int> *> impl;
2454   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2455               result);
2456 }
2457 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)2458 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
2459                               int incx, DeviceMemory<int> *result) {
2460   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2461 
2462   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
2463       impl;
2464   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2465               result);
2466 }
2467 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)2468 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
2469                               int incx, DeviceMemory<int> *result) {
2470   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2471 
2472   ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
2473       impl;
2474   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2475               result);
2476 }
2477 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)2478 Stream &Stream::ThenBlasIamin(uint64 elem_count,
2479                               const DeviceMemory<std::complex<float>> &x,
2480                               int incx, DeviceMemory<int> *result) {
2481   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2482 
2483   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2484                DeviceMemory<int> *> impl;
2485   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2486               result);
2487 }
2488 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)2489 Stream &Stream::ThenBlasIamin(uint64 elem_count,
2490                               const DeviceMemory<std::complex<double>> &x,
2491                               int incx, DeviceMemory<int> *result) {
2492   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2493 
2494   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2495                DeviceMemory<int> *> impl;
2496   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2497               result);
2498 }
2499 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2500 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2501                              uint64 kl, uint64 ku, float alpha,
2502                              const DeviceMemory<float> &a, int lda,
2503                              const DeviceMemory<float> &x, int incx, float beta,
2504                              DeviceMemory<float> *y, int incy) {
2505   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2506             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2507             PARAM(beta), PARAM(y), PARAM(incy));
2508 
2509   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, float,
2510                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2511                int, float, DeviceMemory<float> *, int> impl;
2512   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2513               a, lda, x, incx, beta, y, incy);
2514 }
2515 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2516 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2517                              uint64 kl, uint64 ku, double alpha,
2518                              const DeviceMemory<double> &a, int lda,
2519                              const DeviceMemory<double> &x, int incx,
2520                              double beta, DeviceMemory<double> *y, int incy) {
2521   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2522             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2523             PARAM(beta), PARAM(y), PARAM(incy));
2524 
2525   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, double,
2526                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2527                int, double, DeviceMemory<double> *, int> impl;
2528   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2529               a, lda, x, incx, beta, y, incy);
2530 }
2531 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2532 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2533                              uint64 kl, uint64 ku, std::complex<float> alpha,
2534                              const DeviceMemory<std::complex<float>> &a,
2535                              int lda,
2536                              const DeviceMemory<std::complex<float>> &x,
2537                              int incx, std::complex<float> beta,
2538                              DeviceMemory<std::complex<float>> *y, int incy) {
2539   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2540             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2541             PARAM(beta), PARAM(y), PARAM(incy));
2542 
2543   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
2544                std::complex<float>, const DeviceMemory<std::complex<float>> &,
2545                int, const DeviceMemory<std::complex<float>> &, int,
2546                std::complex<float>, DeviceMemory<std::complex<float>> *,
2547                int> impl;
2548   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2549               a, lda, x, incx, beta, y, incy);
2550 }
2551 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2552 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2553                              uint64 kl, uint64 ku, std::complex<double> alpha,
2554                              const DeviceMemory<std::complex<double>> &a,
2555                              int lda,
2556                              const DeviceMemory<std::complex<double>> &x,
2557                              int incx, std::complex<double> beta,
2558                              DeviceMemory<std::complex<double>> *y, int incy) {
2559   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2560             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2561             PARAM(beta), PARAM(y), PARAM(incy));
2562 
2563   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
2564                std::complex<double>, const DeviceMemory<std::complex<double>> &,
2565                int, const DeviceMemory<std::complex<double>> &, int,
2566                std::complex<double>, DeviceMemory<std::complex<double>> *,
2567                int> impl;
2568   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2569               a, lda, x, incx, beta, y, incy);
2570 }
2571 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2572 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2573                              float alpha, const DeviceMemory<float> &a, int lda,
2574                              const DeviceMemory<float> &x, int incx, float beta,
2575                              DeviceMemory<float> *y, int incy) {
2576   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2577             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2578             PARAM(incy));
2579 
2580   ThenBlasImpl<blas::Transpose, uint64, uint64, float,
2581                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2582                int, float, DeviceMemory<float> *, int> impl;
2583   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2584               x, incx, beta, y, incy);
2585 }
2586 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2587 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2588                              double alpha, const DeviceMemory<double> &a,
2589                              int lda, const DeviceMemory<double> &x, int incx,
2590                              double beta, DeviceMemory<double> *y, int incy) {
2591   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2592             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2593             PARAM(incy));
2594 
2595   ThenBlasImpl<blas::Transpose, uint64, uint64, double,
2596                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2597                int, double, DeviceMemory<double> *, int> impl;
2598   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2599               x, incx, beta, y, incy);
2600 }
2601 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2602 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2603                              std::complex<float> alpha,
2604                              const DeviceMemory<std::complex<float>> &a,
2605                              int lda,
2606                              const DeviceMemory<std::complex<float>> &x,
2607                              int incx, std::complex<float> beta,
2608                              DeviceMemory<std::complex<float>> *y, int incy) {
2609   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2610             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2611             PARAM(incy));
2612 
2613   ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<float>,
2614                const DeviceMemory<std::complex<float>> &, int,
2615                const DeviceMemory<std::complex<float>> &, int,
2616                std::complex<float>, DeviceMemory<std::complex<float>> *,
2617                int> impl;
2618   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2619               x, incx, beta, y, incy);
2620 }
2621 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2622 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2623                              std::complex<double> alpha,
2624                              const DeviceMemory<std::complex<double>> &a,
2625                              int lda,
2626                              const DeviceMemory<std::complex<double>> &x,
2627                              int incx, std::complex<double> beta,
2628                              DeviceMemory<std::complex<double>> *y, int incy) {
2629   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2630             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2631             PARAM(incy));
2632 
2633   ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<double>,
2634                const DeviceMemory<std::complex<double>> &, int,
2635                const DeviceMemory<std::complex<double>> &, int,
2636                std::complex<double>, DeviceMemory<std::complex<double>> *,
2637                int> impl;
2638   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2639               x, incx, beta, y, incy);
2640 }
2641 
ThenBlasGer(uint64 m,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)2642 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, float alpha,
2643                             const DeviceMemory<float> &x, int incx,
2644                             const DeviceMemory<float> &y, int incy,
2645                             DeviceMemory<float> *a, int lda) {
2646   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2647             PARAM(incy), PARAM(a), PARAM(lda));
2648 
2649   ThenBlasImpl<uint64, uint64, float, const DeviceMemory<float> &, int,
2650                const DeviceMemory<float> &, int, DeviceMemory<float> *,
2651                int> impl;
2652   return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2653               incy, a, lda);
2654 }
2655 
ThenBlasGer(uint64 m,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)2656 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, double alpha,
2657                             const DeviceMemory<double> &x, int incx,
2658                             const DeviceMemory<double> &y, int incy,
2659                             DeviceMemory<double> *a, int lda) {
2660   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2661             PARAM(incy), PARAM(a), PARAM(lda));
2662 
2663   ThenBlasImpl<uint64, uint64, double, const DeviceMemory<double> &, int,
2664                const DeviceMemory<double> &, int, DeviceMemory<double> *,
2665                int> impl;
2666   return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2667               incy, a, lda);
2668 }
2669 
ThenBlasGerc(uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2670 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
2671                              const DeviceMemory<std::complex<float>> &x,
2672                              int incx,
2673                              const DeviceMemory<std::complex<float>> &y,
2674                              int incy, DeviceMemory<std::complex<float>> *a,
2675                              int lda) {
2676   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2677             PARAM(incy), PARAM(a), PARAM(lda));
2678 
2679   ThenBlasImpl<uint64, uint64, std::complex<float>,
2680                const DeviceMemory<std::complex<float>> &, int,
2681                const DeviceMemory<std::complex<float>> &, int,
2682                DeviceMemory<std::complex<float>> *, int> impl;
2683   return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2684               incy, a, lda);
2685 }
2686 
ThenBlasGerc(uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2687 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
2688                              const DeviceMemory<std::complex<double>> &x,
2689                              int incx,
2690                              const DeviceMemory<std::complex<double>> &y,
2691                              int incy, DeviceMemory<std::complex<double>> *a,
2692                              int lda) {
2693   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2694             PARAM(incy), PARAM(a), PARAM(lda));
2695 
2696   ThenBlasImpl<uint64, uint64, std::complex<double>,
2697                const DeviceMemory<std::complex<double>> &, int,
2698                const DeviceMemory<std::complex<double>> &, int,
2699                DeviceMemory<std::complex<double>> *, int> impl;
2700   return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2701               incy, a, lda);
2702 }
2703 
ThenBlasGeru(uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2704 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
2705                              const DeviceMemory<std::complex<float>> &x,
2706                              int incx,
2707                              const DeviceMemory<std::complex<float>> &y,
2708                              int incy, DeviceMemory<std::complex<float>> *a,
2709                              int lda) {
2710   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2711             PARAM(incy), PARAM(a), PARAM(lda));
2712 
2713   ThenBlasImpl<uint64, uint64, std::complex<float>,
2714                const DeviceMemory<std::complex<float>> &, int,
2715                const DeviceMemory<std::complex<float>> &, int,
2716                DeviceMemory<std::complex<float>> *, int> impl;
2717   return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2718               incy, a, lda);
2719 }
2720 
ThenBlasGeru(uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2721 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
2722                              const DeviceMemory<std::complex<double>> &x,
2723                              int incx,
2724                              const DeviceMemory<std::complex<double>> &y,
2725                              int incy, DeviceMemory<std::complex<double>> *a,
2726                              int lda) {
2727   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2728             PARAM(incy), PARAM(a), PARAM(lda));
2729 
2730   ThenBlasImpl<uint64, uint64, std::complex<double>,
2731                const DeviceMemory<std::complex<double>> &, int,
2732                const DeviceMemory<std::complex<double>> &, int,
2733                DeviceMemory<std::complex<double>> *, int> impl;
2734   return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2735               incy, a, lda);
2736 }
2737 
ThenBlasHbmv(blas::UpperLower uplo,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2738 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2739                              std::complex<float> alpha,
2740                              const DeviceMemory<std::complex<float>> &a,
2741                              int lda,
2742                              const DeviceMemory<std::complex<float>> &x,
2743                              int incx, std::complex<float> beta,
2744                              DeviceMemory<std::complex<float>> *y, int incy) {
2745   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2746             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2747 
2748   ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<float>,
2749                const DeviceMemory<std::complex<float>> &, int,
2750                const DeviceMemory<std::complex<float>> &, int,
2751                std::complex<float>, DeviceMemory<std::complex<float>> *,
2752                int> impl;
2753   return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2754               x, incx, beta, y, incy);
2755 }
2756 
ThenBlasHbmv(blas::UpperLower uplo,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2757 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2758                              std::complex<double> alpha,
2759                              const DeviceMemory<std::complex<double>> &a,
2760                              int lda,
2761                              const DeviceMemory<std::complex<double>> &x,
2762                              int incx, std::complex<double> beta,
2763                              DeviceMemory<std::complex<double>> *y, int incy) {
2764   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2765             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2766 
2767   ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<double>,
2768                const DeviceMemory<std::complex<double>> &, int,
2769                const DeviceMemory<std::complex<double>> &, int,
2770                std::complex<double>, DeviceMemory<std::complex<double>> *,
2771                int> impl;
2772   return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2773               x, incx, beta, y, incy);
2774 }
2775 
ThenBlasHemv(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2776 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2777                              std::complex<float> alpha,
2778                              const DeviceMemory<std::complex<float>> &a,
2779                              int lda,
2780                              const DeviceMemory<std::complex<float>> &x,
2781                              int incx, std::complex<float> beta,
2782                              DeviceMemory<std::complex<float>> *y, int incy) {
2783   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2784             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2785 
2786   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2787                const DeviceMemory<std::complex<float>> &, int,
2788                const DeviceMemory<std::complex<float>> &, int,
2789                std::complex<float>, DeviceMemory<std::complex<float>> *,
2790                int> impl;
2791   return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2792               incx, beta, y, incy);
2793 }
2794 
ThenBlasHemv(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2795 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2796                              std::complex<double> alpha,
2797                              const DeviceMemory<std::complex<double>> &a,
2798                              int lda,
2799                              const DeviceMemory<std::complex<double>> &x,
2800                              int incx, std::complex<double> beta,
2801                              DeviceMemory<std::complex<double>> *y, int incy) {
2802   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2803             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2804 
2805   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2806                const DeviceMemory<std::complex<double>> &, int,
2807                const DeviceMemory<std::complex<double>> &, int,
2808                std::complex<double>, DeviceMemory<std::complex<double>> *,
2809                int> impl;
2810   return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2811               incx, beta, y, incy);
2812 }
2813 
ThenBlasHer(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * a,int lda)2814 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
2815                             const DeviceMemory<std::complex<float>> &x,
2816                             int incx, DeviceMemory<std::complex<float>> *a,
2817                             int lda) {
2818   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2819             PARAM(a), PARAM(lda));
2820 
2821   ThenBlasImpl<blas::UpperLower, uint64, float,
2822                const DeviceMemory<std::complex<float>> &, int,
2823                DeviceMemory<std::complex<float>> *, int> impl;
2824   return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2825               lda);
2826 }
2827 
ThenBlasHer(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * a,int lda)2828 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
2829                             const DeviceMemory<std::complex<double>> &x,
2830                             int incx, DeviceMemory<std::complex<double>> *a,
2831                             int lda) {
2832   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2833             PARAM(a), PARAM(lda));
2834 
2835   ThenBlasImpl<blas::UpperLower, uint64, double,
2836                const DeviceMemory<std::complex<double>> &, int,
2837                DeviceMemory<std::complex<double>> *, int> impl;
2838   return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2839               lda);
2840 }
2841 
ThenBlasHer2(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2842 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
2843                              std::complex<float> alpha,
2844                              const DeviceMemory<std::complex<float>> &x,
2845                              int incx,
2846                              const DeviceMemory<std::complex<float>> &y,
2847                              int incy, DeviceMemory<std::complex<float>> *a,
2848                              int lda) {
2849   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2850             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2851 
2852   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2853                const DeviceMemory<std::complex<float>> &, int,
2854                const DeviceMemory<std::complex<float>> &, int,
2855                DeviceMemory<std::complex<float>> *, int> impl;
2856   return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
2857               incy, a, lda);
2858 }
2859 
ThenBlasHer2(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2860 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
2861                              std::complex<double> alpha,
2862                              const DeviceMemory<std::complex<double>> &x,
2863                              int incx,
2864                              const DeviceMemory<std::complex<double>> &y,
2865                              int incy, DeviceMemory<std::complex<double>> *a,
2866                              int lda) {
2867   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2868             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2869 
2870   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2871                const DeviceMemory<std::complex<double>> &, int,
2872                const DeviceMemory<std::complex<double>> &, int,
2873                DeviceMemory<std::complex<double>> *, int> impl;
2874   return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
2875               incy, a, lda);
2876 }
2877 
ThenBlasHpmv(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & ap,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2878 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
2879                              std::complex<float> alpha,
2880                              const DeviceMemory<std::complex<float>> &ap,
2881                              const DeviceMemory<std::complex<float>> &x,
2882                              int incx, std::complex<float> beta,
2883                              DeviceMemory<std::complex<float>> *y, int incy) {
2884   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2885             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2886 
2887   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2888                const DeviceMemory<std::complex<float>> &,
2889                const DeviceMemory<std::complex<float>> &, int,
2890                std::complex<float>, DeviceMemory<std::complex<float>> *,
2891                int> impl;
2892   return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
2893               beta, y, incy);
2894 }
2895 
ThenBlasHpmv(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & ap,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2896 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
2897                              std::complex<double> alpha,
2898                              const DeviceMemory<std::complex<double>> &ap,
2899                              const DeviceMemory<std::complex<double>> &x,
2900                              int incx, std::complex<double> beta,
2901                              DeviceMemory<std::complex<double>> *y, int incy) {
2902   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2903             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2904 
2905   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2906                const DeviceMemory<std::complex<double>> &,
2907                const DeviceMemory<std::complex<double>> &, int,
2908                std::complex<double>, DeviceMemory<std::complex<double>> *,
2909                int> impl;
2910   return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
2911               beta, y, incy);
2912 }
2913 
ThenBlasHpr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * ap)2914 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
2915                             const DeviceMemory<std::complex<float>> &x,
2916                             int incx, DeviceMemory<std::complex<float>> *ap) {
2917   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2918             PARAM(ap));
2919 
2920   ThenBlasImpl<blas::UpperLower, uint64, float,
2921                const DeviceMemory<std::complex<float>> &, int,
2922                DeviceMemory<std::complex<float>> *> impl;
2923   return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
2924 }
2925 
ThenBlasHpr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * ap)2926 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
2927                             const DeviceMemory<std::complex<double>> &x,
2928                             int incx, DeviceMemory<std::complex<double>> *ap) {
2929   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2930             PARAM(ap));
2931 
2932   ThenBlasImpl<blas::UpperLower, uint64, double,
2933                const DeviceMemory<std::complex<double>> &, int,
2934                DeviceMemory<std::complex<double>> *> impl;
2935   return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
2936 }
2937 
ThenBlasHpr2(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * ap)2938 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
2939                              std::complex<float> alpha,
2940                              const DeviceMemory<std::complex<float>> &x,
2941                              int incx,
2942                              const DeviceMemory<std::complex<float>> &y,
2943                              int incy, DeviceMemory<std::complex<float>> *ap) {
2944   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2945             PARAM(y), PARAM(incy), PARAM(ap));
2946 
2947   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2948                const DeviceMemory<std::complex<float>> &, int,
2949                const DeviceMemory<std::complex<float>> &, int,
2950                DeviceMemory<std::complex<float>> *> impl;
2951   return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
2952               incy, ap);
2953 }
2954 
ThenBlasHpr2(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * ap)2955 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
2956                              std::complex<double> alpha,
2957                              const DeviceMemory<std::complex<double>> &x,
2958                              int incx,
2959                              const DeviceMemory<std::complex<double>> &y,
2960                              int incy, DeviceMemory<std::complex<double>> *ap) {
2961   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2962             PARAM(y), PARAM(incy), PARAM(ap));
2963 
2964   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2965                const DeviceMemory<std::complex<double>> &, int,
2966                const DeviceMemory<std::complex<double>> &, int,
2967                DeviceMemory<std::complex<double>> *> impl;
2968   return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
2969               incy, ap);
2970 }
2971 
ThenBlasSbmv(blas::UpperLower uplo,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2972 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2973                              float alpha, const DeviceMemory<float> &a, int lda,
2974                              const DeviceMemory<float> &x, int incx, float beta,
2975                              DeviceMemory<float> *y, int incy) {
2976   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2977             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2978 
2979   ThenBlasImpl<blas::UpperLower, uint64, uint64, float,
2980                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2981                int, float, DeviceMemory<float> *, int> impl;
2982   return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
2983               x, incx, beta, y, incy);
2984 }
2985 
ThenBlasSbmv(blas::UpperLower uplo,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2986 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2987                              double alpha, const DeviceMemory<double> &a,
2988                              int lda, const DeviceMemory<double> &x, int incx,
2989                              double beta, DeviceMemory<double> *y, int incy) {
2990   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2991             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2992 
2993   ThenBlasImpl<blas::UpperLower, uint64, uint64, double,
2994                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2995                int, double, DeviceMemory<double> *, int> impl;
2996   return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
2997               x, incx, beta, y, incy);
2998 }
2999 
ThenBlasSpmv(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & ap,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)3000 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
3001                              const DeviceMemory<float> &ap,
3002                              const DeviceMemory<float> &x, int incx, float beta,
3003                              DeviceMemory<float> *y, int incy) {
3004   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
3005             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3006 
3007   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3008                const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
3009                int> impl;
3010   return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
3011               beta, y, incy);
3012 }
3013 
ThenBlasSpmv(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & ap,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)3014 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
3015                              const DeviceMemory<double> &ap,
3016                              const DeviceMemory<double> &x, int incx,
3017                              double beta, DeviceMemory<double> *y, int incy) {
3018   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
3019             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3020 
3021   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3022                const DeviceMemory<double> &, int, double,
3023                DeviceMemory<double> *, int> impl;
3024   return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
3025               beta, y, incy);
3026 }
3027 
ThenBlasSpr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * ap)3028 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
3029                             const DeviceMemory<float> &x, int incx,
3030                             DeviceMemory<float> *ap) {
3031   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3032             PARAM(ap));
3033 
3034   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3035                int, DeviceMemory<float> *> impl;
3036   return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
3037 }
3038 
ThenBlasSpr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * ap)3039 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
3040                             const DeviceMemory<double> &x, int incx,
3041                             DeviceMemory<double> *ap) {
3042   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3043             PARAM(ap));
3044 
3045   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3046                int, DeviceMemory<double> *> impl;
3047   return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
3048 }
3049 
ThenBlasSpr2(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * ap)3050 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
3051                              const DeviceMemory<float> &x, int incx,
3052                              const DeviceMemory<float> &y, int incy,
3053                              DeviceMemory<float> *ap) {
3054   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3055             PARAM(y), PARAM(incy), PARAM(ap));
3056 
3057   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3058                int, const DeviceMemory<float> &, int,
3059                DeviceMemory<float> *> impl;
3060   return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
3061               incy, ap);
3062 }
3063 
ThenBlasSpr2(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * ap)3064 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
3065                              const DeviceMemory<double> &x, int incx,
3066                              const DeviceMemory<double> &y, int incy,
3067                              DeviceMemory<double> *ap) {
3068   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3069             PARAM(y), PARAM(incy), PARAM(ap));
3070 
3071   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3072                int, const DeviceMemory<double> &, int,
3073                DeviceMemory<double> *> impl;
3074   return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
3075               incy, ap);
3076 }
3077 
ThenBlasSymv(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)3078 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
3079                              const DeviceMemory<float> &a, int lda,
3080                              const DeviceMemory<float> &x, int incx, float beta,
3081                              DeviceMemory<float> *y, int incy) {
3082   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
3083             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3084 
3085   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3086                int, const DeviceMemory<float> &, int, float,
3087                DeviceMemory<float> *, int> impl;
3088   return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
3089               incx, beta, y, incy);
3090 }
3091 
ThenBlasSymv(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)3092 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
3093                              const DeviceMemory<double> &a, int lda,
3094                              const DeviceMemory<double> &x, int incx,
3095                              double beta, DeviceMemory<double> *y, int incy) {
3096   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
3097             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3098 
3099   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3100                int, const DeviceMemory<double> &, int, double,
3101                DeviceMemory<double> *, int> impl;
3102   return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
3103               incx, beta, y, incy);
3104 }
3105 
ThenBlasSyr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * a,int lda)3106 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
3107                             const DeviceMemory<float> &x, int incx,
3108                             DeviceMemory<float> *a, int lda) {
3109   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3110             PARAM(a), PARAM(lda));
3111 
3112   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3113                int, DeviceMemory<float> *, int> impl;
3114   return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
3115               lda);
3116 }
3117 
ThenBlasSyr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * a,int lda)3118 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
3119                             const DeviceMemory<double> &x, int incx,
3120                             DeviceMemory<double> *a, int lda) {
3121   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3122             PARAM(a), PARAM(lda));
3123 
3124   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3125                int, DeviceMemory<double> *, int> impl;
3126   return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
3127               lda);
3128 }
3129 
ThenBlasSyr2(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)3130 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
3131                              const DeviceMemory<float> &x, int incx,
3132                              const DeviceMemory<float> &y, int incy,
3133                              DeviceMemory<float> *a, int lda) {
3134   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3135             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
3136 
3137   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3138                int, const DeviceMemory<float> &, int, DeviceMemory<float> *,
3139                int> impl;
3140   return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
3141               incy, a, lda);
3142 }
3143 
ThenBlasSyr2(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)3144 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
3145                              const DeviceMemory<double> &x, int incx,
3146                              const DeviceMemory<double> &y, int incy,
3147                              DeviceMemory<double> *a, int lda) {
3148   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3149             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
3150 
3151   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3152                int, const DeviceMemory<double> &, int, DeviceMemory<double> *,
3153                int> impl;
3154   return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
3155               incy, a, lda);
3156 }
3157 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3158 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3159                              blas::Diagonal diag, uint64 n, uint64 k,
3160                              const DeviceMemory<float> &a, int lda,
3161                              DeviceMemory<float> *x, int incx) {
3162   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3163             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3164 
3165   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3166                uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
3167                int> impl;
3168   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3169               lda, x, incx);
3170 }
3171 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3172 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3173                              blas::Diagonal diag, uint64 n, uint64 k,
3174                              const DeviceMemory<double> &a, int lda,
3175                              DeviceMemory<double> *x, int incx) {
3176   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3177             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3178 
3179   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3180                uint64, const DeviceMemory<double> &, int,
3181                DeviceMemory<double> *, int> impl;
3182   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3183               lda, x, incx);
3184 }
3185 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3186 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3187                              blas::Diagonal diag, uint64 n, uint64 k,
3188                              const DeviceMemory<std::complex<float>> &a,
3189                              int lda, DeviceMemory<std::complex<float>> *x,
3190                              int incx) {
3191   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3192             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3193 
3194   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3195                uint64, const DeviceMemory<std::complex<float>> &, int,
3196                DeviceMemory<std::complex<float>> *, int> impl;
3197   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3198               lda, x, incx);
3199 }
3200 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3201 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3202                              blas::Diagonal diag, uint64 n, uint64 k,
3203                              const DeviceMemory<std::complex<double>> &a,
3204                              int lda, DeviceMemory<std::complex<double>> *x,
3205                              int incx) {
3206   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3207             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3208 
3209   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3210                uint64, const DeviceMemory<std::complex<double>> &, int,
3211                DeviceMemory<std::complex<double>> *, int> impl;
3212   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3213               lda, x, incx);
3214 }
3215 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3216 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3217                              blas::Diagonal diag, uint64 n, uint64 k,
3218                              const DeviceMemory<float> &a, int lda,
3219                              DeviceMemory<float> *x, int incx) {
3220   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3221             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3222 
3223   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3224                uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
3225                int> impl;
3226   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3227               lda, x, incx);
3228 }
3229 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3230 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3231                              blas::Diagonal diag, uint64 n, uint64 k,
3232                              const DeviceMemory<double> &a, int lda,
3233                              DeviceMemory<double> *x, int incx) {
3234   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3235             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3236 
3237   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3238                uint64, const DeviceMemory<double> &, int,
3239                DeviceMemory<double> *, int> impl;
3240   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3241               lda, x, incx);
3242 }
3243 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3244 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3245                              blas::Diagonal diag, uint64 n, uint64 k,
3246                              const DeviceMemory<std::complex<float>> &a,
3247                              int lda, DeviceMemory<std::complex<float>> *x,
3248                              int incx) {
3249   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3250             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3251 
3252   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3253                uint64, const DeviceMemory<std::complex<float>> &, int,
3254                DeviceMemory<std::complex<float>> *, int> impl;
3255   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3256               lda, x, incx);
3257 }
3258 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3259 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3260                              blas::Diagonal diag, uint64 n, uint64 k,
3261                              const DeviceMemory<std::complex<double>> &a,
3262                              int lda, DeviceMemory<std::complex<double>> *x,
3263                              int incx) {
3264   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3265             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3266 
3267   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3268                uint64, const DeviceMemory<std::complex<double>> &, int,
3269                DeviceMemory<std::complex<double>> *, int> impl;
3270   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3271               lda, x, incx);
3272 }
3273 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)3274 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3275                              blas::Diagonal diag, uint64 n,
3276                              const DeviceMemory<float> &ap,
3277                              DeviceMemory<float> *x, int incx) {
3278   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3279             PARAM(x), PARAM(incx));
3280 
3281   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3282                const DeviceMemory<float> &, DeviceMemory<float> *, int> impl;
3283   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3284               incx);
3285 }
3286 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)3287 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3288                              blas::Diagonal diag, uint64 n,
3289                              const DeviceMemory<double> &ap,
3290                              DeviceMemory<double> *x, int incx) {
3291   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3292             PARAM(x), PARAM(incx));
3293 
3294   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3295                const DeviceMemory<double> &, DeviceMemory<double> *, int> impl;
3296   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3297               incx);
3298 }
3299 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)3300 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3301                              blas::Diagonal diag, uint64 n,
3302                              const DeviceMemory<std::complex<float>> &ap,
3303                              DeviceMemory<std::complex<float>> *x, int incx) {
3304   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3305             PARAM(x), PARAM(incx));
3306 
3307   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3308                const DeviceMemory<std::complex<float>> &,
3309                DeviceMemory<std::complex<float>> *, int> impl;
3310   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3311               incx);
3312 }
3313 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)3314 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3315                              blas::Diagonal diag, uint64 n,
3316                              const DeviceMemory<std::complex<double>> &ap,
3317                              DeviceMemory<std::complex<double>> *x, int incx) {
3318   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3319             PARAM(x), PARAM(incx));
3320 
3321   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3322                const DeviceMemory<std::complex<double>> &,
3323                DeviceMemory<std::complex<double>> *, int> impl;
3324   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3325               incx);
3326 }
3327 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)3328 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3329                              blas::Diagonal diag, uint64 n,
3330                              const DeviceMemory<float> &ap,
3331                              DeviceMemory<float> *x, int incx) {
3332   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3333             PARAM(x), PARAM(incx));
3334 
3335   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3336                const DeviceMemory<float> &, DeviceMemory<float> *, int> impl;
3337   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3338               incx);
3339 }
3340 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)3341 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3342                              blas::Diagonal diag, uint64 n,
3343                              const DeviceMemory<double> &ap,
3344                              DeviceMemory<double> *x, int incx) {
3345   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3346             PARAM(x), PARAM(incx));
3347 
3348   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3349                const DeviceMemory<double> &, DeviceMemory<double> *, int> impl;
3350   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3351               incx);
3352 }
3353 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)3354 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3355                              blas::Diagonal diag, uint64 n,
3356                              const DeviceMemory<std::complex<float>> &ap,
3357                              DeviceMemory<std::complex<float>> *x, int incx) {
3358   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3359             PARAM(x), PARAM(incx));
3360 
3361   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3362                const DeviceMemory<std::complex<float>> &,
3363                DeviceMemory<std::complex<float>> *, int> impl;
3364   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3365               incx);
3366 }
3367 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)3368 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3369                              blas::Diagonal diag, uint64 n,
3370                              const DeviceMemory<std::complex<double>> &ap,
3371                              DeviceMemory<std::complex<double>> *x, int incx) {
3372   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3373             PARAM(x), PARAM(incx));
3374 
3375   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3376                const DeviceMemory<std::complex<double>> &,
3377                DeviceMemory<std::complex<double>> *, int> impl;
3378   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3379               incx);
3380 }
3381 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3382 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3383                              blas::Diagonal diag, uint64 n,
3384                              const DeviceMemory<float> &a, int lda,
3385                              DeviceMemory<float> *x, int incx) {
3386   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3387             PARAM(lda), PARAM(x), PARAM(incx));
3388 
3389   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3390                const DeviceMemory<float> &, int, DeviceMemory<float> *,
3391                int> impl;
3392   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3393               lda, x, incx);
3394 }
3395 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3396 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3397                              blas::Diagonal diag, uint64 n,
3398                              const DeviceMemory<double> &a, int lda,
3399                              DeviceMemory<double> *x, int incx) {
3400   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3401             PARAM(lda), PARAM(x), PARAM(incx));
3402 
3403   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3404                const DeviceMemory<double> &, int, DeviceMemory<double> *,
3405                int> impl;
3406   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3407               lda, x, incx);
3408 }
3409 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3410 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3411                              blas::Diagonal diag, uint64 n,
3412                              const DeviceMemory<std::complex<float>> &a,
3413                              int lda, DeviceMemory<std::complex<float>> *x,
3414                              int incx) {
3415   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3416             PARAM(lda), PARAM(x), PARAM(incx));
3417 
3418   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3419                const DeviceMemory<std::complex<float>> &, int,
3420                DeviceMemory<std::complex<float>> *, int> impl;
3421   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3422               lda, x, incx);
3423 }
3424 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3425 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3426                              blas::Diagonal diag, uint64 n,
3427                              const DeviceMemory<std::complex<double>> &a,
3428                              int lda, DeviceMemory<std::complex<double>> *x,
3429                              int incx) {
3430   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3431             PARAM(lda), PARAM(x), PARAM(incx));
3432 
3433   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3434                const DeviceMemory<std::complex<double>> &, int,
3435                DeviceMemory<std::complex<double>> *, int> impl;
3436   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3437               lda, x, incx);
3438 }
3439 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3440 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3441                              blas::Diagonal diag, uint64 n,
3442                              const DeviceMemory<float> &a, int lda,
3443                              DeviceMemory<float> *x, int incx) {
3444   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3445             PARAM(lda), PARAM(x), PARAM(incx));
3446 
3447   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3448                const DeviceMemory<float> &, int, DeviceMemory<float> *,
3449                int> impl;
3450   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3451               lda, x, incx);
3452 }
3453 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3454 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3455                              blas::Diagonal diag, uint64 n,
3456                              const DeviceMemory<double> &a, int lda,
3457                              DeviceMemory<double> *x, int incx) {
3458   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3459             PARAM(lda), PARAM(x), PARAM(incx));
3460 
3461   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3462                const DeviceMemory<double> &, int, DeviceMemory<double> *,
3463                int> impl;
3464   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3465               lda, x, incx);
3466 }
3467 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3468 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3469                              blas::Diagonal diag, uint64 n,
3470                              const DeviceMemory<std::complex<float>> &a,
3471                              int lda, DeviceMemory<std::complex<float>> *x,
3472                              int incx) {
3473   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3474             PARAM(lda), PARAM(x), PARAM(incx));
3475 
3476   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3477                const DeviceMemory<std::complex<float>> &, int,
3478                DeviceMemory<std::complex<float>> *, int> impl;
3479   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3480               lda, x, incx);
3481 }
3482 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3483 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3484                              blas::Diagonal diag, uint64 n,
3485                              const DeviceMemory<std::complex<double>> &a,
3486                              int lda, DeviceMemory<std::complex<double>> *x,
3487                              int incx) {
3488   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3489             PARAM(lda), PARAM(x), PARAM(incx));
3490 
3491   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3492                const DeviceMemory<std::complex<double>> &, int,
3493                DeviceMemory<std::complex<double>> *, int> impl;
3494   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3495               lda, x, incx);
3496 }
3497 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc)3498 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3499                              uint64 m, uint64 n, uint64 k, float alpha,
3500                              const DeviceMemory<Eigen::half> &a, int lda,
3501                              const DeviceMemory<Eigen::half> &b, int ldb,
3502                              float beta,
3503                              DeviceMemory<Eigen::half> *c, int ldc) {
3504   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3505             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3506             PARAM(beta), PARAM(c), PARAM(ldc));
3507 
3508   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3509                const DeviceMemory<Eigen::half> &, int,
3510                const DeviceMemory<Eigen::half> &, int,
3511                float, DeviceMemory<Eigen::half> *, int> impl;
3512   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3513               alpha, a, lda, b, ldb, beta, c, ldc);
3514 }
3515 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)3516 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3517                              uint64 m, uint64 n, uint64 k, float alpha,
3518                              const DeviceMemory<float> &a, int lda,
3519                              const DeviceMemory<float> &b, int ldb, float beta,
3520                              DeviceMemory<float> *c, int ldc) {
3521   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3522             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3523             PARAM(beta), PARAM(c), PARAM(ldc));
3524 
3525   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3526                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
3527                int, float, DeviceMemory<float> *, int> impl;
3528   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3529               alpha, a, lda, b, ldb, beta, c, ldc);
3530 }
3531 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)3532 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3533                              uint64 m, uint64 n, uint64 k, double alpha,
3534                              const DeviceMemory<double> &a, int lda,
3535                              const DeviceMemory<double> &b, int ldb,
3536                              double beta, DeviceMemory<double> *c, int ldc) {
3537   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3538             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3539             PARAM(beta), PARAM(c), PARAM(ldc));
3540 
3541   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
3542                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
3543                int, double, DeviceMemory<double> *, int> impl;
3544   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3545               alpha, a, lda, b, ldb, beta, c, ldc);
3546 }
3547 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3548 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3549                              uint64 m, uint64 n, uint64 k,
3550                              std::complex<float> alpha,
3551                              const DeviceMemory<std::complex<float>> &a,
3552                              int lda,
3553                              const DeviceMemory<std::complex<float>> &b,
3554                              int ldb, std::complex<float> beta,
3555                              DeviceMemory<std::complex<float>> *c, int ldc) {
3556   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3557             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3558             PARAM(beta), PARAM(c), PARAM(ldc));
3559 
3560   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3561                std::complex<float>, const DeviceMemory<std::complex<float>> &,
3562                int, const DeviceMemory<std::complex<float>> &, int,
3563                std::complex<float>, DeviceMemory<std::complex<float>> *,
3564                int> impl;
3565   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3566               alpha, a, lda, b, ldb, beta, c, ldc);
3567 }
3568 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3569 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3570                              uint64 m, uint64 n, uint64 k,
3571                              std::complex<double> alpha,
3572                              const DeviceMemory<std::complex<double>> &a,
3573                              int lda,
3574                              const DeviceMemory<std::complex<double>> &b,
3575                              int ldb, std::complex<double> beta,
3576                              DeviceMemory<std::complex<double>> *c, int ldc) {
3577   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3578             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3579             PARAM(beta), PARAM(c), PARAM(ldc));
3580 
3581   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3582                std::complex<double>, const DeviceMemory<std::complex<double>> &,
3583                int, const DeviceMemory<std::complex<double>> &, int,
3584                std::complex<double>, DeviceMemory<std::complex<double>> *,
3585                int> impl;
3586   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3587               alpha, a, lda, b, ldb, beta, c, ldc);
3588 }
3589 
3590 namespace {
3591 // Like ThenBlasImpl, except this expects the last argument of blas_func to be a
3592 // blas::ProfileResult*.  This functor doesn't put the stream into an error
3593 // state if the op fails and the profile result is non-null.  Instead, the
3594 // error-ness is returned in the profile result itself.
3595 template <typename... Args>
3596 struct ThenBlasWithProfileImpl {
operator ()stream_executor::__anon3e261ebe0211::ThenBlasWithProfileImpl3597   Stream &operator()(Stream *stream,
3598                      bool (blas::BlasSupport::*blas_func)(
3599                          Stream *, Args..., blas::ProfileResult *),
3600                      Args... args, blas::ProfileResult *profile_result) {
3601     ThenBlasImpl<Args..., blas::ProfileResult *> Runner;
3602     bool record_error = profile_result == nullptr;
3603     return Runner.Run(stream, blas_func, record_error, args..., profile_result);
3604   }
3605 };
3606 }  // anonymous namespace
3607 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy,blas::ProfileResult * output_profile_result)3608 Stream &Stream::ThenBlasGemvWithProfiling(
3609     blas::Transpose trans, uint64 m, uint64 n, float alpha,
3610     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
3611     int incx, float beta, DeviceMemory<float> *y, int incy,
3612     blas::ProfileResult *output_profile_result) {
3613   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3614             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3615             PARAM(incy));
3616 
3617   ThenBlasWithProfileImpl<
3618       blas::Transpose, uint64, uint64, float, const DeviceMemory<float> &, int,
3619       const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int>
3620       impl;
3621   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3622               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3623 }
3624 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy,blas::ProfileResult * output_profile_result)3625 Stream &Stream::ThenBlasGemvWithProfiling(
3626     blas::Transpose trans, uint64 m, uint64 n, double alpha,
3627     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
3628     int incx, double beta, DeviceMemory<double> *y, int incy,
3629     blas::ProfileResult *output_profile_result) {
3630   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3631             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3632             PARAM(incy));
3633 
3634   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, double,
3635                           const DeviceMemory<double> &, int,
3636                           const DeviceMemory<double> &, int, double,
3637                           DeviceMemory<double> *, int>
3638       impl;
3639   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3640               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3641 }
3642 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy,blas::ProfileResult * output_profile_result)3643 Stream &Stream::ThenBlasGemvWithProfiling(
3644     blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
3645     const DeviceMemory<std::complex<float>> &a, int lda,
3646     const DeviceMemory<std::complex<float>> &x, int incx,
3647     std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
3648     blas::ProfileResult *output_profile_result) {
3649   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3650             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3651             PARAM(incy));
3652 
3653   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<float>,
3654                           const DeviceMemory<std::complex<float>> &, int,
3655                           const DeviceMemory<std::complex<float>> &, int,
3656                           std::complex<float>,
3657                           DeviceMemory<std::complex<float>> *, int>
3658       impl;
3659   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3660               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3661 }
3662 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy,blas::ProfileResult * output_profile_result)3663 Stream &Stream::ThenBlasGemvWithProfiling(
3664     blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
3665     const DeviceMemory<std::complex<double>> &a, int lda,
3666     const DeviceMemory<std::complex<double>> &x, int incx,
3667     std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
3668     blas::ProfileResult *output_profile_result) {
3669   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3670             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3671             PARAM(incy));
3672 
3673   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<double>,
3674                           const DeviceMemory<std::complex<double>> &, int,
3675                           const DeviceMemory<std::complex<double>> &, int,
3676                           std::complex<double>,
3677                           DeviceMemory<std::complex<double>> *, int>
3678       impl;
3679   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3680               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3681 }
3682 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ProfileResult * output_profile_result)3683 Stream &Stream::ThenBlasGemmWithProfiling(
3684     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3685     uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
3686     const DeviceMemory<Eigen::half> &b, int ldb, float beta,
3687     DeviceMemory<Eigen::half> *c, int ldc,
3688     blas::ProfileResult *output_profile_result) {
3689   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3690             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3691             PARAM(beta), PARAM(c), PARAM(ldc));
3692 
3693   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3694                           uint64, float, const DeviceMemory<Eigen::half> &, int,
3695                           const DeviceMemory<Eigen::half> &, int, float,
3696                           DeviceMemory<Eigen::half> *, int>
3697       impl;
3698   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3699               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3700               output_profile_result);
3701 }
3702 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ProfileResult * output_profile_result)3703 Stream &Stream::ThenBlasGemmWithProfiling(
3704     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3705     uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
3706     const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
3707     int ldc, blas::ProfileResult *output_profile_result) {
3708   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3709             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3710             PARAM(beta), PARAM(c), PARAM(ldc));
3711 
3712   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3713                           uint64, float, const DeviceMemory<float> &, int,
3714                           const DeviceMemory<float> &, int, float,
3715                           DeviceMemory<float> *, int>
3716       impl;
3717   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3718               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3719               output_profile_result);
3720 }
3721 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ProfileResult * output_profile_result)3722 Stream &Stream::ThenBlasGemmWithProfiling(
3723     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3724     uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
3725     const DeviceMemory<double> &b, int ldb, double beta,
3726     DeviceMemory<double> *c, int ldc,
3727     blas::ProfileResult *output_profile_result) {
3728   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3729             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3730             PARAM(beta), PARAM(c), PARAM(ldc));
3731 
3732   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3733                           uint64, double, const DeviceMemory<double> &, int,
3734                           const DeviceMemory<double> &, int, double,
3735                           DeviceMemory<double> *, int>
3736       impl;
3737   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3738               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3739               output_profile_result);
3740 }
3741 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ProfileResult * output_profile_result)3742 Stream &Stream::ThenBlasGemmWithProfiling(
3743     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3744     uint64 k, std::complex<float> alpha,
3745     const DeviceMemory<std::complex<float>> &a, int lda,
3746     const DeviceMemory<std::complex<float>> &b, int ldb,
3747     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
3748     blas::ProfileResult *output_profile_result) {
3749   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3750             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3751             PARAM(beta), PARAM(c), PARAM(ldc));
3752 
3753   ThenBlasWithProfileImpl<
3754       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3755       std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
3756       const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
3757       DeviceMemory<std::complex<float>> *, int>
3758       impl;
3759   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3760               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3761               output_profile_result);
3762 }
3763 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ProfileResult * output_profile_result)3764 Stream &Stream::ThenBlasGemmWithProfiling(
3765     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3766     uint64 k, std::complex<double> alpha,
3767     const DeviceMemory<std::complex<double>> &a, int lda,
3768     const DeviceMemory<std::complex<double>> &b, int ldb,
3769     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
3770     blas::ProfileResult *output_profile_result) {
3771   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3772             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3773             PARAM(beta), PARAM(c), PARAM(ldc));
3774 
3775   ThenBlasWithProfileImpl<
3776       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3777       std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
3778       const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
3779       DeviceMemory<std::complex<double>> *, int>
3780       impl;
3781   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3782               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3783               output_profile_result);
3784 }
3785 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<Eigen::half> & alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,const HostOrDeviceScalar<Eigen::half> & beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3786 Stream &Stream::ThenBlasGemmWithAlgorithm(
3787     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3788     uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
3789     const DeviceMemory<Eigen::half> &a, int lda,
3790     const DeviceMemory<Eigen::half> &b, int ldb,
3791     const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
3792     int ldc, blas::ComputationType computation_type,
3793     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3794   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3795             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3796             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3797             PARAM(algorithm));
3798 
3799   ThenBlasWithProfileImpl<
3800       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3801       const HostOrDeviceScalar<Eigen::half> &,
3802       const DeviceMemory<Eigen::half> &, int, const DeviceMemory<Eigen::half> &,
3803       int, const HostOrDeviceScalar<Eigen::half> &, DeviceMemory<Eigen::half> *,
3804       int, blas::ComputationType, blas::AlgorithmType>
3805       impl;
3806   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3807               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3808               algorithm, output_profile_result);
3809 }
3810 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<int> & alpha,const DeviceMemory<int8> & a,int lda,const DeviceMemory<int8> & b,int ldb,const HostOrDeviceScalar<int> & beta,DeviceMemory<int> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3811 Stream &Stream::ThenBlasGemmWithAlgorithm(
3812     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3813     uint64 k, const HostOrDeviceScalar<int> &alpha, const DeviceMemory<int8> &a,
3814     int lda, const DeviceMemory<int8> &b, int ldb,
3815     const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c, int ldc,
3816     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3817     blas::ProfileResult *output_profile_result) {
3818   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3819             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3820             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3821             PARAM(algorithm));
3822 
3823   ThenBlasWithProfileImpl<
3824       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3825       const HostOrDeviceScalar<int> &, const DeviceMemory<int8> &, int,
3826       const DeviceMemory<int8> &, int, const HostOrDeviceScalar<int> &,
3827       DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
3828       impl;
3829   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3830               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3831               algorithm, output_profile_result);
3832 }
3833 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<float> & alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,const HostOrDeviceScalar<float> & beta,DeviceMemory<float> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3834 Stream &Stream::ThenBlasGemmWithAlgorithm(
3835     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3836     uint64 k, const HostOrDeviceScalar<float> &alpha,
3837     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
3838     int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
3839     int ldc, blas::ComputationType computation_type,
3840     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3841   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3842             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3843             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3844             PARAM(algorithm));
3845 
3846   ThenBlasWithProfileImpl<
3847       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3848       const HostOrDeviceScalar<float> &, const DeviceMemory<float> &, int,
3849       const DeviceMemory<float> &, int, const HostOrDeviceScalar<float> &,
3850       DeviceMemory<float> *, int, blas::ComputationType, blas::AlgorithmType>
3851       impl;
3852   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3853               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3854               algorithm, output_profile_result);
3855 }
3856 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<double> & alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,const HostOrDeviceScalar<double> & beta,DeviceMemory<double> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3857 Stream &Stream::ThenBlasGemmWithAlgorithm(
3858     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3859     uint64 k, const HostOrDeviceScalar<double> &alpha,
3860     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
3861     int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
3862     int ldc, blas::ComputationType computation_type,
3863     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3864   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3865             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3866             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3867             PARAM(algorithm));
3868 
3869   ThenBlasWithProfileImpl<
3870       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3871       const HostOrDeviceScalar<double> &, const DeviceMemory<double> &, int,
3872       const DeviceMemory<double> &, int, const HostOrDeviceScalar<double> &,
3873       DeviceMemory<double> *, int, blas::ComputationType, blas::AlgorithmType>
3874       impl;
3875   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3876               m, n, k, HostOrDeviceScalar<double>(alpha), a, lda, b, ldb,
3877               HostOrDeviceScalar<double>(beta), c, ldc, computation_type,
3878               algorithm, output_profile_result);
3879 }
3880 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<std::complex<float>> & alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,const HostOrDeviceScalar<std::complex<float>> & beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3881 Stream &Stream::ThenBlasGemmWithAlgorithm(
3882     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3883     uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
3884     const DeviceMemory<std::complex<float>> &a, int lda,
3885     const DeviceMemory<std::complex<float>> &b, int ldb,
3886     const HostOrDeviceScalar<std::complex<float>> &beta,
3887     DeviceMemory<std::complex<float>> *c, int ldc,
3888     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3889     blas::ProfileResult *output_profile_result) {
3890   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3891             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3892             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3893             PARAM(algorithm));
3894 
3895   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3896                           uint64,
3897                           const HostOrDeviceScalar<std::complex<float>> &,
3898                           const DeviceMemory<std::complex<float>> &, int,
3899                           const DeviceMemory<std::complex<float>> &, int,
3900                           const HostOrDeviceScalar<std::complex<float>> &,
3901                           DeviceMemory<std::complex<float>> *, int,
3902                           blas::ComputationType, blas::AlgorithmType>
3903       impl;
3904   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3905               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3906               algorithm, output_profile_result);
3907 }
3908 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<std::complex<double>> & alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,const HostOrDeviceScalar<std::complex<double>> & beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3909 Stream &Stream::ThenBlasGemmWithAlgorithm(
3910     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3911     uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
3912     const DeviceMemory<std::complex<double>> &a, int lda,
3913     const DeviceMemory<std::complex<double>> &b, int ldb,
3914     const HostOrDeviceScalar<std::complex<double>> &beta,
3915     DeviceMemory<std::complex<double>> *c, int ldc,
3916     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3917     blas::ProfileResult *output_profile_result) {
3918   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3919             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3920             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3921             PARAM(algorithm));
3922 
3923   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3924                           uint64,
3925                           const HostOrDeviceScalar<std::complex<double>> &,
3926                           const DeviceMemory<std::complex<double>> &, int,
3927                           const DeviceMemory<std::complex<double>> &, int,
3928                           const HostOrDeviceScalar<std::complex<double>> &,
3929                           DeviceMemory<std::complex<double>> *, int,
3930                           blas::ComputationType, blas::AlgorithmType>
3931       impl;
3932   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3933               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3934               algorithm, output_profile_result);
3935 }
3936 
ThenBlasHemm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3937 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
3938                              uint64 n, std::complex<float> alpha,
3939                              const DeviceMemory<std::complex<float>> &a,
3940                              int lda,
3941                              const DeviceMemory<std::complex<float>> &b,
3942                              int ldb, std::complex<float> beta,
3943                              DeviceMemory<std::complex<float>> *c, int ldc) {
3944   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3945             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3946             PARAM(ldc));
3947 
3948   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3949                std::complex<float>, const DeviceMemory<std::complex<float>> &,
3950                int, const DeviceMemory<std::complex<float>> &, int,
3951                std::complex<float>, DeviceMemory<std::complex<float>> *,
3952                int> impl;
3953   return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
3954               lda, b, ldb, beta, c, ldc);
3955 }
3956 
ThenBlasHemm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3957 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
3958                              uint64 n, std::complex<double> alpha,
3959                              const DeviceMemory<std::complex<double>> &a,
3960                              int lda,
3961                              const DeviceMemory<std::complex<double>> &b,
3962                              int ldb, std::complex<double> beta,
3963                              DeviceMemory<std::complex<double>> *c, int ldc) {
3964   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3965             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3966             PARAM(ldc));
3967 
3968   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3969                std::complex<double>, const DeviceMemory<std::complex<double>> &,
3970                int, const DeviceMemory<std::complex<double>> &, int,
3971                std::complex<double>, DeviceMemory<std::complex<double>> *,
3972                int> impl;
3973   return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
3974               lda, b, ldb, beta, c, ldc);
3975 }
3976 
ThenBlasHerk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<std::complex<float>> & a,int lda,float beta,DeviceMemory<std::complex<float>> * c,int ldc)3977 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
3978                              uint64 n, uint64 k, float alpha,
3979                              const DeviceMemory<std::complex<float>> &a,
3980                              int lda, float beta,
3981                              DeviceMemory<std::complex<float>> *c, int ldc) {
3982   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3983             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3984 
3985   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
3986                const DeviceMemory<std::complex<float>> &, int, float,
3987                DeviceMemory<std::complex<float>> *, int> impl;
3988   return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
3989               lda, beta, c, ldc);
3990 }
3991 
ThenBlasHerk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<std::complex<double>> & a,int lda,double beta,DeviceMemory<std::complex<double>> * c,int ldc)3992 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
3993                              uint64 n, uint64 k, double alpha,
3994                              const DeviceMemory<std::complex<double>> &a,
3995                              int lda, double beta,
3996                              DeviceMemory<std::complex<double>> *c, int ldc) {
3997   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3998             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3999 
4000   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
4001                const DeviceMemory<std::complex<double>> &, int, double,
4002                DeviceMemory<std::complex<double>> *, int> impl;
4003   return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
4004               lda, beta, c, ldc);
4005 }
4006 
ThenBlasHer2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,float beta,DeviceMemory<std::complex<float>> * c,int ldc)4007 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
4008                               uint64 n, uint64 k, std::complex<float> alpha,
4009                               const DeviceMemory<std::complex<float>> &a,
4010                               int lda,
4011                               const DeviceMemory<std::complex<float>> &b,
4012                               int ldb, float beta,
4013                               DeviceMemory<std::complex<float>> *c, int ldc) {
4014   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4015             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4016             PARAM(ldc));
4017 
4018   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4019                std::complex<float>, const DeviceMemory<std::complex<float>> &,
4020                int, const DeviceMemory<std::complex<float>> &, int, float,
4021                DeviceMemory<std::complex<float>> *, int> impl;
4022   return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
4023               a, lda, b, ldb, beta, c, ldc);
4024 }
4025 
ThenBlasHer2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,double beta,DeviceMemory<std::complex<double>> * c,int ldc)4026 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
4027                               uint64 n, uint64 k, std::complex<double> alpha,
4028                               const DeviceMemory<std::complex<double>> &a,
4029                               int lda,
4030                               const DeviceMemory<std::complex<double>> &b,
4031                               int ldb, double beta,
4032                               DeviceMemory<std::complex<double>> *c, int ldc) {
4033   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4034             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4035             PARAM(ldc));
4036 
4037   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4038                std::complex<double>, const DeviceMemory<std::complex<double>> &,
4039                int, const DeviceMemory<std::complex<double>> &, int, double,
4040                DeviceMemory<std::complex<double>> *, int> impl;
4041   return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
4042               a, lda, b, ldb, beta, c, ldc);
4043 }
4044 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)4045 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4046                              uint64 n, float alpha,
4047                              const DeviceMemory<float> &a, int lda,
4048                              const DeviceMemory<float> &b, int ldb, float beta,
4049                              DeviceMemory<float> *c, int ldc) {
4050   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4051             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4052             PARAM(ldc));
4053 
4054   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, float,
4055                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
4056                int, float, DeviceMemory<float> *, int> impl;
4057   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4058               lda, b, ldb, beta, c, ldc);
4059 }
4060 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)4061 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4062                              uint64 n, double alpha,
4063                              const DeviceMemory<double> &a, int lda,
4064                              const DeviceMemory<double> &b, int ldb,
4065                              double beta, DeviceMemory<double> *c, int ldc) {
4066   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4067             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4068             PARAM(ldc));
4069 
4070   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, double,
4071                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
4072                int, double, DeviceMemory<double> *, int> impl;
4073   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4074               lda, b, ldb, beta, c, ldc);
4075 }
4076 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)4077 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4078                              uint64 n, std::complex<float> alpha,
4079                              const DeviceMemory<std::complex<float>> &a,
4080                              int lda,
4081                              const DeviceMemory<std::complex<float>> &b,
4082                              int ldb, std::complex<float> beta,
4083                              DeviceMemory<std::complex<float>> *c, int ldc) {
4084   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4085             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4086             PARAM(ldc));
4087 
4088   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
4089                std::complex<float>, const DeviceMemory<std::complex<float>> &,
4090                int, const DeviceMemory<std::complex<float>> &, int,
4091                std::complex<float>, DeviceMemory<std::complex<float>> *,
4092                int> impl;
4093   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4094               lda, b, ldb, beta, c, ldc);
4095 }
4096 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)4097 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4098                              uint64 n, std::complex<double> alpha,
4099                              const DeviceMemory<std::complex<double>> &a,
4100                              int lda,
4101                              const DeviceMemory<std::complex<double>> &b,
4102                              int ldb, std::complex<double> beta,
4103                              DeviceMemory<std::complex<double>> *c, int ldc) {
4104   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4105             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4106             PARAM(ldc));
4107 
4108   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
4109                std::complex<double>, const DeviceMemory<std::complex<double>> &,
4110                int, const DeviceMemory<std::complex<double>> &, int,
4111                std::complex<double>, DeviceMemory<std::complex<double>> *,
4112                int> impl;
4113   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4114               lda, b, ldb, beta, c, ldc);
4115 }
4116 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,float beta,DeviceMemory<float> * c,int ldc)4117 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4118                              uint64 n, uint64 k, float alpha,
4119                              const DeviceMemory<float> &a, int lda, float beta,
4120                              DeviceMemory<float> *c, int ldc) {
4121   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4122             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4123 
4124   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
4125                const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
4126                int> impl;
4127   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4128               lda, beta, c, ldc);
4129 }
4130 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,double beta,DeviceMemory<double> * c,int ldc)4131 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4132                              uint64 n, uint64 k, double alpha,
4133                              const DeviceMemory<double> &a, int lda,
4134                              double beta, DeviceMemory<double> *c, int ldc) {
4135   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4136             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4137 
4138   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
4139                const DeviceMemory<double> &, int, double,
4140                DeviceMemory<double> *, int> impl;
4141   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4142               lda, beta, c, ldc);
4143 }
4144 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)4145 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4146                              uint64 n, uint64 k, std::complex<float> alpha,
4147                              const DeviceMemory<std::complex<float>> &a,
4148                              int lda, std::complex<float> beta,
4149                              DeviceMemory<std::complex<float>> *c, int ldc) {
4150   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4151             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4152 
4153   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4154                std::complex<float>, const DeviceMemory<std::complex<float>> &,
4155                int, std::complex<float>, DeviceMemory<std::complex<float>> *,
4156                int> impl;
4157   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4158               lda, beta, c, ldc);
4159 }
4160 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)4161 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4162                              uint64 n, uint64 k, std::complex<double> alpha,
4163                              const DeviceMemory<std::complex<double>> &a,
4164                              int lda, std::complex<double> beta,
4165                              DeviceMemory<std::complex<double>> *c, int ldc) {
4166   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4167             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4168 
4169   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4170                std::complex<double>, const DeviceMemory<std::complex<double>> &,
4171                int, std::complex<double>, DeviceMemory<std::complex<double>> *,
4172                int> impl;
4173   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4174               lda, beta, c, ldc);
4175 }
4176 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)4177 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4178                               uint64 n, uint64 k, float alpha,
4179                               const DeviceMemory<float> &a, int lda,
4180                               const DeviceMemory<float> &b, int ldb, float beta,
4181                               DeviceMemory<float> *c, int ldc) {
4182   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4183             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4184             PARAM(ldc));
4185 
4186   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
4187                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
4188                int, float, DeviceMemory<float> *, int> impl;
4189   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4190               a, lda, b, ldb, beta, c, ldc);
4191 }
4192 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)4193 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4194                               uint64 n, uint64 k, double alpha,
4195                               const DeviceMemory<double> &a, int lda,
4196                               const DeviceMemory<double> &b, int ldb,
4197                               double beta, DeviceMemory<double> *c, int ldc) {
4198   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4199             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4200             PARAM(ldc));
4201 
4202   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
4203                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
4204                int, double, DeviceMemory<double> *, int> impl;
4205   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4206               a, lda, b, ldb, beta, c, ldc);
4207 }
4208 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)4209 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4210                               uint64 n, uint64 k, std::complex<float> alpha,
4211                               const DeviceMemory<std::complex<float>> &a,
4212                               int lda,
4213                               const DeviceMemory<std::complex<float>> &b,
4214                               int ldb, std::complex<float> beta,
4215                               DeviceMemory<std::complex<float>> *c, int ldc) {
4216   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4217             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4218             PARAM(ldc));
4219 
4220   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4221                std::complex<float>, const DeviceMemory<std::complex<float>> &,
4222                int, const DeviceMemory<std::complex<float>> &, int,
4223                std::complex<float>, DeviceMemory<std::complex<float>> *,
4224                int> impl;
4225   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4226               a, lda, b, ldb, beta, c, ldc);
4227 }
4228 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)4229 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4230                               uint64 n, uint64 k, std::complex<double> alpha,
4231                               const DeviceMemory<std::complex<double>> &a,
4232                               int lda,
4233                               const DeviceMemory<std::complex<double>> &b,
4234                               int ldb, std::complex<double> beta,
4235                               DeviceMemory<std::complex<double>> *c, int ldc) {
4236   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4237             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4238             PARAM(ldc));
4239 
4240   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4241                std::complex<double>, const DeviceMemory<std::complex<double>> &,
4242                int, const DeviceMemory<std::complex<double>> &, int,
4243                std::complex<double>, DeviceMemory<std::complex<double>> *,
4244                int> impl;
4245   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4246               a, lda, b, ldb, beta, c, ldc);
4247 }
4248 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)4249 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4250                              blas::Transpose transa, blas::Diagonal diag,
4251                              uint64 m, uint64 n, float alpha,
4252                              const DeviceMemory<float> &a, int lda,
4253                              DeviceMemory<float> *b, int ldb) {
4254   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4255             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4256 
4257   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4258                uint64, uint64, float, const DeviceMemory<float> &, int,
4259                DeviceMemory<float> *, int> impl;
4260   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4261               n, alpha, a, lda, b, ldb);
4262 }
4263 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)4264 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4265                              blas::Transpose transa, blas::Diagonal diag,
4266                              uint64 m, uint64 n, double alpha,
4267                              const DeviceMemory<double> &a, int lda,
4268                              DeviceMemory<double> *b, int ldb) {
4269   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4270             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4271 
4272   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4273                uint64, uint64, double, const DeviceMemory<double> &, int,
4274                DeviceMemory<double> *, int> impl;
4275   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4276               n, alpha, a, lda, b, ldb);
4277 }
4278 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)4279 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4280                              blas::Transpose transa, blas::Diagonal diag,
4281                              uint64 m, uint64 n, std::complex<float> alpha,
4282                              const DeviceMemory<std::complex<float>> &a,
4283                              int lda, DeviceMemory<std::complex<float>> *b,
4284                              int ldb) {
4285   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4286             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4287 
4288   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4289                uint64, uint64, std::complex<float>,
4290                const DeviceMemory<std::complex<float>> &, int,
4291                DeviceMemory<std::complex<float>> *, int> impl;
4292   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4293               n, alpha, a, lda, b, ldb);
4294 }
4295 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)4296 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4297                              blas::Transpose transa, blas::Diagonal diag,
4298                              uint64 m, uint64 n, std::complex<double> alpha,
4299                              const DeviceMemory<std::complex<double>> &a,
4300                              int lda, DeviceMemory<std::complex<double>> *b,
4301                              int ldb) {
4302   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4303             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4304 
4305   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4306                uint64, uint64, std::complex<double>,
4307                const DeviceMemory<std::complex<double>> &, int,
4308                DeviceMemory<std::complex<double>> *, int> impl;
4309   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4310               n, alpha, a, lda, b, ldb);
4311 }
4312 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)4313 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4314                              blas::Transpose transa, blas::Diagonal diag,
4315                              uint64 m, uint64 n, float alpha,
4316                              const DeviceMemory<float> &a, int lda,
4317                              DeviceMemory<float> *b, int ldb) {
4318   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4319             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4320 
4321   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4322                uint64, uint64, float, const DeviceMemory<float> &, int,
4323                DeviceMemory<float> *, int> impl;
4324   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4325               n, alpha, a, lda, b, ldb);
4326 }
4327 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)4328 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4329                              blas::Transpose transa, blas::Diagonal diag,
4330                              uint64 m, uint64 n, double alpha,
4331                              const DeviceMemory<double> &a, int lda,
4332                              DeviceMemory<double> *b, int ldb) {
4333   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4334             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4335 
4336   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4337                uint64, uint64, double, const DeviceMemory<double> &, int,
4338                DeviceMemory<double> *, int> impl;
4339   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4340               n, alpha, a, lda, b, ldb);
4341 }
4342 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)4343 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4344                              blas::Transpose transa, blas::Diagonal diag,
4345                              uint64 m, uint64 n, std::complex<float> alpha,
4346                              const DeviceMemory<std::complex<float>> &a,
4347                              int lda, DeviceMemory<std::complex<float>> *b,
4348                              int ldb) {
4349   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4350             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4351 
4352   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4353                uint64, uint64, std::complex<float>,
4354                const DeviceMemory<std::complex<float>> &, int,
4355                DeviceMemory<std::complex<float>> *, int> impl;
4356   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4357               n, alpha, a, lda, b, ldb);
4358 }
4359 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)4360 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4361                              blas::Transpose transa, blas::Diagonal diag,
4362                              uint64 m, uint64 n, std::complex<double> alpha,
4363                              const DeviceMemory<std::complex<double>> &a,
4364                              int lda, DeviceMemory<std::complex<double>> *b,
4365                              int ldb) {
4366   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4367             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4368 
4369   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4370                uint64, uint64, std::complex<double>,
4371                const DeviceMemory<std::complex<double>> &, int,
4372                DeviceMemory<std::complex<double>> *, int> impl;
4373   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4374               n, alpha, a, lda, b, ldb);
4375 }
4376 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count)4377 Stream &Stream::ThenBlasGemmBatched(
4378     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4379     uint64 k, float alpha,
4380     const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
4381     const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
4382     const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
4383     int batch_count) {
4384   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4385                                         b, ldb, beta, c, ldc, batch_count,
4386                                         /*scratch_allocator=*/nullptr);
4387 }
4388 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4389 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4390     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4391     uint64 k, float alpha,
4392     const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
4393     const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
4394     const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
4395     int batch_count, ScratchAllocator *scratch_allocator) {
4396   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4397             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4398             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4399 
4400   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4401                const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
4402                const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
4403                float, const port::ArraySlice<DeviceMemory<Eigen::half> *> &,
4404                int, int, ScratchAllocator *>
4405       impl;
4406   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4407               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4408               scratch_allocator);
4409 }
4410 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count)4411 Stream &Stream::ThenBlasGemmBatched(
4412     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4413     uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
4414     int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
4415     float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
4416     int batch_count) {
4417   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4418                                         b, ldb, beta, c, ldc, batch_count,
4419                                         /*scratch_allocator=*/nullptr);
4420 }
4421 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4422 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4423     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4424     uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
4425     int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
4426     float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
4427     int batch_count, ScratchAllocator *scratch_allocator) {
4428   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4429             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4430             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4431 
4432   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4433                const port::ArraySlice<DeviceMemory<float> *> &, int,
4434                const port::ArraySlice<DeviceMemory<float> *> &, int, float,
4435                const port::ArraySlice<DeviceMemory<float> *> &, int, int,
4436                ScratchAllocator *>
4437       impl;
4438   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4439               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4440               scratch_allocator);
4441 }
4442 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count)4443 Stream &Stream::ThenBlasGemmBatched(
4444     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4445     uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
4446     int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
4447     double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
4448     int batch_count) {
4449   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4450                                         b, ldb, beta, c, ldc, batch_count,
4451                                         /*scratch_allocator=*/nullptr);
4452 }
4453 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4454 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4455     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4456     uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
4457     int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
4458     double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
4459     int batch_count, ScratchAllocator *scratch_allocator) {
4460   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4461             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4462             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4463 
4464   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
4465                const port::ArraySlice<DeviceMemory<double> *> &, int,
4466                const port::ArraySlice<DeviceMemory<double> *> &, int, double,
4467                const port::ArraySlice<DeviceMemory<double> *> &, int, int,
4468                ScratchAllocator *>
4469       impl;
4470   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4471               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4472               scratch_allocator);
4473 }
4474 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count)4475 Stream &Stream::ThenBlasGemmBatched(
4476     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4477     uint64 k, std::complex<float> alpha,
4478     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
4479     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
4480     std::complex<float> beta,
4481     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
4482     int batch_count) {
4483   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4484                                         b, ldb, beta, c, ldc, batch_count,
4485                                         /*scratch_allocator=*/nullptr);
4486 }
4487 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4488 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4489     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4490     uint64 k, std::complex<float> alpha,
4491     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
4492     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
4493     std::complex<float> beta,
4494     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
4495     int batch_count, ScratchAllocator *scratch_allocator) {
4496   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4497             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4498             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4499 
4500   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4501                std::complex<float>,
4502                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4503                int,
4504                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4505                int, std::complex<float>,
4506                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4507                int, int, ScratchAllocator *>
4508       impl;
4509   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4510               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4511               scratch_allocator);
4512 }
4513 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c,int ldc,int batch_count)4514 Stream &Stream::ThenBlasGemmBatched(
4515     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4516     uint64 k, std::complex<double> alpha,
4517     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
4518     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
4519     std::complex<double> beta,
4520     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
4521     int batch_count) {
4522   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4523                                         b, ldb, beta, c, ldc, batch_count,
4524                                         /*scratch_allocator=*/nullptr);
4525 }
4526 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4527 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4528     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4529     uint64 k, std::complex<double> alpha,
4530     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
4531     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
4532     std::complex<double> beta,
4533     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
4534     int batch_count, ScratchAllocator *scratch_allocator) {
4535   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4536             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4537             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4538 
4539   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4540                std::complex<double>,
4541                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4542                int,
4543                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4544                int, std::complex<double>,
4545                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4546                int, int, ScratchAllocator *>
4547       impl;
4548   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4549               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4550               scratch_allocator);
4551 }
4552 
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,int64 stride_a,const DeviceMemory<Eigen::half> & b,int ldb,int64 stride_b,float beta,DeviceMemory<Eigen::half> * c,int ldc,int64 stride_c,int batch_count)4553 Stream &Stream::ThenBlasGemmStridedBatched(
4554     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4555     uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
4556     int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b,
4557     float beta, DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c,
4558     int batch_count) {
4559   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4560             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4561             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4562             PARAM(stride_c), PARAM(batch_count));
4563 
4564   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4565                const DeviceMemory<Eigen::half> &, int, int64,
4566                const DeviceMemory<Eigen::half> &, int, int64, float,
4567                DeviceMemory<Eigen::half> *, int, int64, int>
4568       impl;
4569   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4570               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4571               c, ldc, stride_c, batch_count);
4572 }
4573 
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,int64 stride_a,const DeviceMemory<float> & b,int ldb,int64 stride_b,float beta,DeviceMemory<float> * c,int ldc,int64 stride_c,int batch_count)4574 Stream &Stream::ThenBlasGemmStridedBatched(
4575     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4576     uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
4577     int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
4578     float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
4579     int batch_count) {
4580   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4581             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4582             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4583             PARAM(stride_c), PARAM(batch_count));
4584 
4585   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4586                const DeviceMemory<float> &, int, int64,
4587                const DeviceMemory<float> &, int, int64, float,
4588                DeviceMemory<float> *, int, int64, int>
4589       impl;
4590   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4591               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4592               c, ldc, stride_c, batch_count);
4593 }
4594 
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,int64 stride_a,const DeviceMemory<double> & b,int ldb,int64 stride_b,double beta,DeviceMemory<double> * c,int ldc,int64 stride_c,int batch_count)4595 Stream &Stream::ThenBlasGemmStridedBatched(
4596     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4597     uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
4598     int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
4599     double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
4600     int batch_count) {
4601   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4602             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4603             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4604             PARAM(stride_c), PARAM(batch_count));
4605 
4606   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
4607                const DeviceMemory<double> &, int, int64,
4608                const DeviceMemory<double> &, int, int64, double,
4609                DeviceMemory<double> *, int, int64, int>
4610       impl;
4611   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4612               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4613               c, ldc, stride_c, batch_count);
4614 }
4615 
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,int64 stride_a,const DeviceMemory<std::complex<float>> & b,int ldb,int64 stride_b,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,int64 stride_c,int batch_count)4616 Stream &Stream::ThenBlasGemmStridedBatched(
4617     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4618     uint64 k, std::complex<float> alpha,
4619     const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
4620     const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
4621     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
4622     int64 stride_c, int batch_count) {
4623   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4624             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4625             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4626             PARAM(stride_c), PARAM(batch_count));
4627 
4628   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4629                std::complex<float>, const DeviceMemory<std::complex<float>> &,
4630                int, int64, const DeviceMemory<std::complex<float>> &, int,
4631                int64, std::complex<float>, DeviceMemory<std::complex<float>> *,
4632                int, int64, int>
4633       impl;
4634   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4635               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4636               c, ldc, stride_c, batch_count);
4637 }
4638 
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,int64 stride_a,const DeviceMemory<std::complex<double>> & b,int ldb,int64 stride_b,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,int64 stride_c,int batch_count)4639 Stream &Stream::ThenBlasGemmStridedBatched(
4640     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4641     uint64 k, std::complex<double> alpha,
4642     const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
4643     const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
4644     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
4645     int64 stride_c, int batch_count) {
4646   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4647             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4648             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4649             PARAM(stride_c), PARAM(batch_count));
4650 
4651   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4652                std::complex<double>, const DeviceMemory<std::complex<double>> &,
4653                int, int64, const DeviceMemory<std::complex<double>> &, int,
4654                int64, std::complex<double>,
4655                DeviceMemory<std::complex<double>> *, int, int64, int>
4656       impl;
4657   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4658               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4659               c, ldc, stride_c, batch_count);
4660 }
4661 
ThenSetRngSeed(const uint8 * seed,uint64 seed_bytes)4662 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
4663   VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
4664 
4665   if (ok()) {
4666     if (rng::RngSupport *rng = parent_->AsRng()) {
4667       CheckError(rng->SetSeed(this, seed, seed_bytes));
4668     } else {
4669       SetError();
4670       LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG";
4671     }
4672   } else {
4673     LOG(INFO) << DebugStreamPointers()
4674               << " did not set RNG seed: " << static_cast<const void *>(seed)
4675               << "; bytes: " << seed_bytes;
4676   }
4677   return *this;
4678 }
4679 
ThenPopulateRandUniform(DeviceMemory<float> * values)4680 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
4681   VLOG_CALL(PARAM(values));
4682 
4683   if (ok()) {
4684     if (rng::RngSupport *rng = parent_->AsRng()) {
4685       CheckError(rng->DoPopulateRandUniform(this, values));
4686     } else {
4687       SetError();
4688       LOG(INFO) << DebugStreamPointers()
4689                 << " attempting to perform RNG operation using StreamExecutor"
4690                    " without RNG support.";
4691     }
4692   }
4693   return *this;
4694 }
4695 
ThenPopulateRandGaussian(float mean,float sd,DeviceMemory<float> * values)4696 Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
4697                                          DeviceMemory<float> *values) {
4698   VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
4699 
4700   if (ok()) {
4701     if (rng::RngSupport *rng = parent_->AsRng()) {
4702       CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
4703     } else {
4704       SetError();
4705       LOG(INFO) << DebugStreamPointers()
4706                 << " attempting to perform RNG operation using StreamExecutor"
4707                    " without RNG support.";
4708     }
4709   }
4710   return *this;
4711 }
4712 
ThenPopulateRandGaussian(double mean,double sd,DeviceMemory<double> * values)4713 Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
4714                                          DeviceMemory<double> *values) {
4715   VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
4716 
4717   if (ok()) {
4718     if (rng::RngSupport *rng = parent_->AsRng()) {
4719       CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
4720     } else {
4721       SetError();
4722       LOG(INFO) << DebugStreamPointers()
4723                 << " attempting to perform RNG operation using StreamExecutor"
4724                    " without RNG support.";
4725     }
4726   }
4727   return *this;
4728 }
4729 
ThenPopulateRandUniform(DeviceMemory<double> * values)4730 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
4731   VLOG_CALL(PARAM(values));
4732 
4733   if (ok()) {
4734     if (rng::RngSupport *rng = parent_->AsRng()) {
4735       CheckError(rng->DoPopulateRandUniform(this, values));
4736     } else {
4737       SetError();
4738       LOG(INFO) << DebugStreamPointers()
4739                 << " attempting to perform RNG operation using StreamExecutor"
4740                    " without RNG support.";
4741     }
4742   }
4743   return *this;
4744 }
4745 
ThenPopulateRandUniform(DeviceMemory<std::complex<float>> * values)4746 Stream &Stream::ThenPopulateRandUniform(
4747     DeviceMemory<std::complex<float>> *values) {
4748   VLOG_CALL(PARAM(values));
4749 
4750   if (ok()) {
4751     if (rng::RngSupport *rng = parent_->AsRng()) {
4752       CheckError(rng->DoPopulateRandUniform(this, values));
4753     } else {
4754       SetError();
4755       LOG(INFO) << DebugStreamPointers()
4756                 << " attempting to perform RNG operation using StreamExecutor"
4757                    " without RNG support.";
4758     }
4759   }
4760   return *this;
4761 }
4762 
ThenPopulateRandUniform(DeviceMemory<std::complex<double>> * values)4763 Stream &Stream::ThenPopulateRandUniform(
4764     DeviceMemory<std::complex<double>> *values) {
4765   VLOG_CALL(PARAM(values));
4766 
4767   if (ok()) {
4768     if (rng::RngSupport *rng = parent_->AsRng()) {
4769       CheckError(rng->DoPopulateRandUniform(this, values));
4770     } else {
4771       SetError();
4772       LOG(INFO) << DebugStreamPointers()
4773                 << " attempting to perform RNG operation using StreamExecutor"
4774                    " without RNG support.";
4775     }
4776   }
4777   return *this;
4778 }
4779 
ThenMemcpy(void * host_dst,const DeviceMemoryBase & gpu_src,uint64 size)4780 Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
4781                            uint64 size) {
4782   VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size));
4783 
4784   if (ok()) {
4785     CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
4786   } else {
4787     LOG(INFO) << DebugStreamPointers()
4788               << " did not memcpy device-to-host; source: " << gpu_src.opaque();
4789   }
4790   return *this;
4791 }
4792 
ThenMemcpy(DeviceMemoryBase * gpu_dst,const void * host_src,uint64 size)4793 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
4794                            uint64 size) {
4795   VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size));
4796 
4797   if (ok()) {
4798     CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
4799   } else {
4800     LOG(INFO) << DebugStreamPointers()
4801               << " did not memcpy host-to-device; source: " << host_src;
4802   }
4803   return *this;
4804 }
4805 
ThenMemcpy(DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)4806 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
4807                            const DeviceMemoryBase &gpu_src, uint64 size) {
4808   VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size));
4809 
4810   if (ok()) {
4811     CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
4812   } else {
4813     LOG(INFO) << DebugStreamPointers()
4814               << " did not memcpy gpu-to-gpu; source: " << &gpu_src;
4815   }
4816   return *this;
4817 }
4818 
ThenMemZero(DeviceMemoryBase * location,uint64 size)4819 Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
4820   VLOG_CALL(PARAM(location), PARAM(size));
4821 
4822   if (ok()) {
4823     CheckError(parent_->MemZero(this, location, size));
4824   } else {
4825     LOG(INFO) << DebugStreamPointers()
4826               << " did not memzero GPU location; source: " << location;
4827   }
4828   return *this;
4829 }
4830 
ThenMemset32(DeviceMemoryBase * location,uint32 pattern,uint64 size)4831 Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
4832                              uint64 size) {
4833   VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size));
4834 
4835   if (ok()) {
4836     CheckError(parent_->Memset32(this, location, pattern, size));
4837   } else {
4838     LOG(INFO) << DebugStreamPointers()
4839               << " did not memset GPU location; source: " << location
4840               << "; size: " << size << "; pattern: " << std::hex << pattern;
4841   }
4842   return *this;
4843 }
4844 
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<Eigen::half> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<Eigen::half> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<Eigen::half> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4845 Stream &Stream::ThenRnnForward(
4846     const dnn::RnnDescriptor &rnn_desc,
4847     const dnn::RnnSequenceTensorDescriptor &input_desc,
4848     const DeviceMemory<Eigen::half> &input_data,
4849     const dnn::RnnStateTensorDescriptor &input_h_desc,
4850     const DeviceMemory<Eigen::half> &input_h_data,
4851     const dnn::RnnStateTensorDescriptor &input_c_desc,
4852     const DeviceMemory<Eigen::half> &input_c_data,
4853     const DeviceMemory<Eigen::half> &params,
4854     const dnn::RnnSequenceTensorDescriptor &output_desc,
4855     DeviceMemory<Eigen::half> *output_data,
4856     const dnn::RnnStateTensorDescriptor &output_h_desc,
4857     DeviceMemory<Eigen::half> *output_h_data,
4858     const dnn::RnnStateTensorDescriptor &output_c_desc,
4859     DeviceMemory<Eigen::half> *output_c_data, bool is_training,
4860     ScratchAllocator *reserve_space_allocator,
4861     ScratchAllocator *workspace_allocator,
4862     dnn::ProfileResult *output_profile_result) {
4863   // TODO(zhengxq): add VLOG PARAM calls.
4864   if (ok()) {
4865     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4866       auto status = dnn->DoRnnForward(
4867           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4868           input_c_desc, input_c_data, params, output_desc, output_data,
4869           output_h_desc, output_h_data, output_c_desc, output_c_data,
4870           is_training, reserve_space_allocator, workspace_allocator,
4871           output_profile_result);
4872       if (!status && !output_profile_result) {
4873         SetError();
4874       }
4875     } else {
4876       SetErrorAndLogNoDnnSupport();
4877     }
4878   }
4879   return *this;
4880 }
4881 
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<float> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<float> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<float> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4882 Stream &Stream::ThenRnnForward(
4883     const dnn::RnnDescriptor &rnn_desc,
4884     const dnn::RnnSequenceTensorDescriptor &input_desc,
4885     const DeviceMemory<float> &input_data,
4886     const dnn::RnnStateTensorDescriptor &input_h_desc,
4887     const DeviceMemory<float> &input_h_data,
4888     const dnn::RnnStateTensorDescriptor &input_c_desc,
4889     const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
4890     const dnn::RnnSequenceTensorDescriptor &output_desc,
4891     DeviceMemory<float> *output_data,
4892     const dnn::RnnStateTensorDescriptor &output_h_desc,
4893     DeviceMemory<float> *output_h_data,
4894     const dnn::RnnStateTensorDescriptor &output_c_desc,
4895     DeviceMemory<float> *output_c_data, bool is_training,
4896     ScratchAllocator *reserve_space_allocator,
4897     ScratchAllocator *workspace_allocator,
4898     dnn::ProfileResult *output_profile_result) {
4899   // TODO(zhengxq): add VLOG PARAM calls.
4900   if (ok()) {
4901     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4902       auto status = dnn->DoRnnForward(
4903           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4904           input_c_desc, input_c_data, params, output_desc, output_data,
4905           output_h_desc, output_h_data, output_c_desc, output_c_data,
4906           is_training, reserve_space_allocator, workspace_allocator,
4907           output_profile_result);
4908       if (!status && !output_profile_result) {
4909         SetError();
4910       }
4911     } else {
4912       SetErrorAndLogNoDnnSupport();
4913     }
4914   }
4915   return *this;
4916 }
4917 
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<double> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<double> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<double> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4918 Stream &Stream::ThenRnnForward(
4919     const dnn::RnnDescriptor &rnn_desc,
4920     const dnn::RnnSequenceTensorDescriptor &input_desc,
4921     const DeviceMemory<double> &input_data,
4922     const dnn::RnnStateTensorDescriptor &input_h_desc,
4923     const DeviceMemory<double> &input_h_data,
4924     const dnn::RnnStateTensorDescriptor &input_c_desc,
4925     const DeviceMemory<double> &input_c_data,
4926     const DeviceMemory<double> &params,
4927     const dnn::RnnSequenceTensorDescriptor &output_desc,
4928     DeviceMemory<double> *output_data,
4929     const dnn::RnnStateTensorDescriptor &output_h_desc,
4930     DeviceMemory<double> *output_h_data,
4931     const dnn::RnnStateTensorDescriptor &output_c_desc,
4932     DeviceMemory<double> *output_c_data, bool is_training,
4933     ScratchAllocator *reserve_space_allocator,
4934     ScratchAllocator *workspace_allocator,
4935     dnn::ProfileResult *output_profile_result) {
4936   // TODO(zhengxq): add VLOG PARAM calls.
4937   if (ok()) {
4938     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4939       auto status = dnn->DoRnnForward(
4940           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4941           input_c_desc, input_c_data, params, output_desc, output_data,
4942           output_h_desc, output_h_data, output_c_desc, output_c_data,
4943           is_training, reserve_space_allocator, workspace_allocator,
4944           output_profile_result);
4945       if (!status && !output_profile_result) {
4946         SetError();
4947       }
4948     } else {
4949       SetErrorAndLogNoDnnSupport();
4950     }
4951   }
4952   return *this;
4953 }
4954 
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<Eigen::half> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<Eigen::half> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<Eigen::half> & output_c_data,const DeviceMemory<Eigen::half> & output_backprop_data,const DeviceMemory<Eigen::half> & output_h_backprop_data,const DeviceMemory<Eigen::half> & output_c_backprop_data,DeviceMemory<Eigen::half> * input_backprop_data,DeviceMemory<Eigen::half> * input_h_backprop_data,DeviceMemory<Eigen::half> * input_c_backprop_data,DeviceMemory<Eigen::half> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4955 Stream &Stream::ThenRnnBackward(
4956     const dnn::RnnDescriptor &rnn_desc,
4957     const dnn::RnnSequenceTensorDescriptor &input_desc,
4958     const DeviceMemory<Eigen::half> &input_data,
4959     const dnn::RnnStateTensorDescriptor &input_h_desc,
4960     const DeviceMemory<Eigen::half> &input_h_data,
4961     const dnn::RnnStateTensorDescriptor &input_c_desc,
4962     const DeviceMemory<Eigen::half> &input_c_data,
4963     const DeviceMemory<Eigen::half> &params,
4964     const dnn::RnnSequenceTensorDescriptor &output_desc,
4965     const DeviceMemory<Eigen::half> &output_data,
4966     const dnn::RnnStateTensorDescriptor &output_h_desc,
4967     const DeviceMemory<Eigen::half> &output_h_data,
4968     const dnn::RnnStateTensorDescriptor &output_c_desc,
4969     const DeviceMemory<Eigen::half> &output_c_data,
4970     const DeviceMemory<Eigen::half> &output_backprop_data,
4971     const DeviceMemory<Eigen::half> &output_h_backprop_data,
4972     const DeviceMemory<Eigen::half> &output_c_backprop_data,
4973     DeviceMemory<Eigen::half> *input_backprop_data,
4974     DeviceMemory<Eigen::half> *input_h_backprop_data,
4975     DeviceMemory<Eigen::half> *input_c_backprop_data,
4976     DeviceMemory<Eigen::half> *params_backprop_data,
4977     DeviceMemory<uint8> *reserve_space_data,
4978     ScratchAllocator *workspace_allocator,
4979     dnn::ProfileResult *output_profile_result) {
4980   // TODO(zhengxq): add VLOG PARAM calls.
4981   if (ok()) {
4982     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4983       auto status = dnn->DoRnnBackward(
4984           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4985           input_c_desc, input_c_data, params, output_desc, output_data,
4986           output_h_desc, output_h_data, output_c_desc, output_c_data,
4987           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4988           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4989           params_backprop_data, reserve_space_data, workspace_allocator,
4990           output_profile_result);
4991       if (!status && !output_profile_result) {
4992         SetError();
4993       }
4994     } else {
4995       SetError();
4996       LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4997     }
4998   }
4999   return *this;
5000 }
5001 
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<float> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<float> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<float> & output_c_data,const DeviceMemory<float> & output_backprop_data,const DeviceMemory<float> & output_h_backprop_data,const DeviceMemory<float> & output_c_backprop_data,DeviceMemory<float> * input_backprop_data,DeviceMemory<float> * input_h_backprop_data,DeviceMemory<float> * input_c_backprop_data,DeviceMemory<float> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)5002 Stream &Stream::ThenRnnBackward(
5003     const dnn::RnnDescriptor &rnn_desc,
5004     const dnn::RnnSequenceTensorDescriptor &input_desc,
5005     const DeviceMemory<float> &input_data,
5006     const dnn::RnnStateTensorDescriptor &input_h_desc,
5007     const DeviceMemory<float> &input_h_data,
5008     const dnn::RnnStateTensorDescriptor &input_c_desc,
5009     const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
5010     const dnn::RnnSequenceTensorDescriptor &output_desc,
5011     const DeviceMemory<float> &output_data,
5012     const dnn::RnnStateTensorDescriptor &output_h_desc,
5013     const DeviceMemory<float> &output_h_data,
5014     const dnn::RnnStateTensorDescriptor &output_c_desc,
5015     const DeviceMemory<float> &output_c_data,
5016     const DeviceMemory<float> &output_backprop_data,
5017     const DeviceMemory<float> &output_h_backprop_data,
5018     const DeviceMemory<float> &output_c_backprop_data,
5019     DeviceMemory<float> *input_backprop_data,
5020     DeviceMemory<float> *input_h_backprop_data,
5021     DeviceMemory<float> *input_c_backprop_data,
5022     DeviceMemory<float> *params_backprop_data,
5023     DeviceMemory<uint8> *reserve_space_data,
5024     ScratchAllocator *workspace_allocator,
5025     dnn::ProfileResult *output_profile_result) {
5026   // TODO(zhengxq): add VLOG PARAM calls.
5027   if (ok()) {
5028     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5029       auto status = dnn->DoRnnBackward(
5030           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
5031           input_c_desc, input_c_data, params, output_desc, output_data,
5032           output_h_desc, output_h_data, output_c_desc, output_c_data,
5033           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
5034           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
5035           params_backprop_data, reserve_space_data, workspace_allocator,
5036           output_profile_result);
5037       if (!status && !output_profile_result) {
5038         SetError();
5039       }
5040     } else {
5041       SetError();
5042       LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
5043     }
5044   }
5045   return *this;
5046 }
5047 
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<double> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<double> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<double> & output_c_data,const DeviceMemory<double> & output_backprop_data,const DeviceMemory<double> & output_h_backprop_data,const DeviceMemory<double> & output_c_backprop_data,DeviceMemory<double> * input_backprop_data,DeviceMemory<double> * input_h_backprop_data,DeviceMemory<double> * input_c_backprop_data,DeviceMemory<double> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)5048 Stream &Stream::ThenRnnBackward(
5049     const dnn::RnnDescriptor &rnn_desc,
5050     const dnn::RnnSequenceTensorDescriptor &input_desc,
5051     const DeviceMemory<double> &input_data,
5052     const dnn::RnnStateTensorDescriptor &input_h_desc,
5053     const DeviceMemory<double> &input_h_data,
5054     const dnn::RnnStateTensorDescriptor &input_c_desc,
5055     const DeviceMemory<double> &input_c_data,
5056     const DeviceMemory<double> &params,
5057     const dnn::RnnSequenceTensorDescriptor &output_desc,
5058     const DeviceMemory<double> &output_data,
5059     const dnn::RnnStateTensorDescriptor &output_h_desc,
5060     const DeviceMemory<double> &output_h_data,
5061     const dnn::RnnStateTensorDescriptor &output_c_desc,
5062     const DeviceMemory<double> &output_c_data,
5063     const DeviceMemory<double> &output_backprop_data,
5064     const DeviceMemory<double> &output_h_backprop_data,
5065     const DeviceMemory<double> &output_c_backprop_data,
5066     DeviceMemory<double> *input_backprop_data,
5067     DeviceMemory<double> *input_h_backprop_data,
5068     DeviceMemory<double> *input_c_backprop_data,
5069     DeviceMemory<double> *params_backprop_data,
5070     DeviceMemory<uint8> *reserve_space_data,
5071     ScratchAllocator *workspace_allocator,
5072     dnn::ProfileResult *output_profile_result) {
5073   // TODO(zhengxq): add VLOG PARAM calls.
5074   if (ok()) {
5075     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5076       auto status = dnn->DoRnnBackward(
5077           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
5078           input_c_desc, input_c_data, params, output_desc, output_data,
5079           output_h_desc, output_h_data, output_c_desc, output_c_data,
5080           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
5081           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
5082           params_backprop_data, reserve_space_data, workspace_allocator,
5083           output_profile_result);
5084       if (!status && !output_profile_result) {
5085         SetError();
5086       }
5087     } else {
5088       SetError();
5089       LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
5090     }
5091   }
5092   return *this;
5093 }
5094 
ThenTransformTensor(const dnn::BatchDescriptor & input_desc,dnn::DataType input_type,const DeviceMemoryBase & input_data,const dnn::BatchDescriptor & output_desc,dnn::DataType output_type,float scale,DeviceMemoryBase * output_data)5095 Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
5096                                     dnn::DataType input_type,
5097                                     const DeviceMemoryBase &input_data,
5098                                     const dnn::BatchDescriptor &output_desc,
5099                                     dnn::DataType output_type, float scale,
5100                                     DeviceMemoryBase *output_data) {
5101   VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data),
5102             PARAM(output_desc), PARAM(output_type), PARAM(scale),
5103             PARAM(output_data));
5104   if (ok()) {
5105     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5106       CheckError(dnn->DoTransformTensor(this, input_desc, input_type,
5107                                         input_data, output_desc, output_type,
5108                                         scale, output_data));
5109     } else {
5110       SetErrorAndLogNoDnnSupport();
5111     }
5112   }
5113   return *this;
5114 }
5115 
ThenDoHostCallback(std::function<void ()> callback)5116 Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
5117   VLOG_CALL(PARAM(callback));
5118 
5119   if (!ok()) {
5120     LOG(INFO) << DebugStreamPointers()
5121               << " was in error state before adding host callback";
5122   }
5123   CheckError(parent_->HostCallback(this, std::move(callback)));
5124   return *this;
5125 }
5126 
ThenDoHostCallbackWithStatus(std::function<port::Status ()> callback)5127 Stream &Stream::ThenDoHostCallbackWithStatus(
5128     std::function<port::Status()> callback) {
5129   VLOG_CALL(PARAM(callback));
5130 
5131   if (!ok()) {
5132     LOG(INFO) << DebugStreamPointers()
5133               << " was in error state before adding host callback";
5134   }
5135   CheckError(parent_->HostCallback(this, std::move(callback)));
5136   return *this;
5137 }
5138 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<std::complex<float>> * output)5139 Stream &Stream::ThenFft(fft::Plan *plan,
5140                         const DeviceMemory<std::complex<float>> &input,
5141                         DeviceMemory<std::complex<float>> *output) {
5142   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5143 
5144   if (ok()) {
5145     if (fft::FftSupport *fft = parent_->AsFft()) {
5146       CheckError(fft->DoFft(this, plan, input, output));
5147     } else {
5148       SetError();
5149       LOG(INFO) << DebugStreamPointers()
5150                 << " attempting to perform FFT operation using StreamExecutor"
5151                    " without FFT support";
5152     }
5153   }
5154   return *this;
5155 }
5156 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<std::complex<double>> * output)5157 Stream &Stream::ThenFft(fft::Plan *plan,
5158                         const DeviceMemory<std::complex<double>> &input,
5159                         DeviceMemory<std::complex<double>> *output) {
5160   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5161 
5162   if (ok()) {
5163     if (fft::FftSupport *fft = parent_->AsFft()) {
5164       CheckError(fft->DoFft(this, plan, input, output));
5165     } else {
5166       SetError();
5167       LOG(INFO) << DebugStreamPointers()
5168                 << " attempting to perform FFT operation using StreamExecutor"
5169                    " without FFT support";
5170     }
5171   }
5172   return *this;
5173 }
5174 
ThenFft(fft::Plan * plan,const DeviceMemory<float> & input,DeviceMemory<std::complex<float>> * output)5175 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
5176                         DeviceMemory<std::complex<float>> *output) {
5177   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5178 
5179   if (ok()) {
5180     if (fft::FftSupport *fft = parent_->AsFft()) {
5181       CheckError(fft->DoFft(this, plan, input, output));
5182     } else {
5183       SetError();
5184       LOG(INFO) << DebugStreamPointers()
5185                 << " attempting to perform FFT operation using StreamExecutor"
5186                    " without FFT support";
5187     }
5188   }
5189   return *this;
5190 }
5191 
ThenFft(fft::Plan * plan,const DeviceMemory<double> & input,DeviceMemory<std::complex<double>> * output)5192 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
5193                         DeviceMemory<std::complex<double>> *output) {
5194   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5195 
5196   if (ok()) {
5197     if (fft::FftSupport *fft = parent_->AsFft()) {
5198       CheckError(fft->DoFft(this, plan, input, output));
5199     } else {
5200       SetError();
5201       LOG(INFO) << DebugStreamPointers()
5202                 << " attempting to perform FFT operation using StreamExecutor"
5203                    " without FFT support";
5204     }
5205   }
5206   return *this;
5207 }
5208 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<float> * output)5209 Stream &Stream::ThenFft(fft::Plan *plan,
5210                         const DeviceMemory<std::complex<float>> &input,
5211                         DeviceMemory<float> *output) {
5212   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5213 
5214   if (ok()) {
5215     if (fft::FftSupport *fft = parent_->AsFft()) {
5216       CheckError(fft->DoFft(this, plan, input, output));
5217     } else {
5218       SetError();
5219       LOG(INFO) << DebugStreamPointers()
5220                 << " attempting to perform FFT operation using StreamExecutor"
5221                    " without FFT support";
5222     }
5223   }
5224   return *this;
5225 }
5226 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<double> * output)5227 Stream &Stream::ThenFft(fft::Plan *plan,
5228                         const DeviceMemory<std::complex<double>> &input,
5229                         DeviceMemory<double> *output) {
5230   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5231 
5232   if (ok()) {
5233     if (fft::FftSupport *fft = parent_->AsFft()) {
5234       CheckError(fft->DoFft(this, plan, input, output));
5235     } else {
5236       SetError();
5237       LOG(INFO) << DebugStreamPointers()
5238                 << " attempting to perform FFT operation using StreamExecutor"
5239                    " without FFT support";
5240     }
5241   }
5242   return *this;
5243 }
5244 
5245 // It looks confusing, but all this is doing is inserting a callback at the
5246 // present point in the stream to then enqueue a task on the host executor.
ThenEnqueueOnBackgroundThread(std::function<void (StreamExecutor *)> task)5247 Stream &Stream::ThenEnqueueOnBackgroundThread(
5248     std::function<void(StreamExecutor *)> task) {
5249   VLOG_CALL(PARAM(task));
5250 
5251   StreamExecutor *stream_executor = this->parent_;
5252   std::function<void()> bound_task = std::bind(task, stream_executor);
5253 
5254   return ThenDoHostCallback([stream_executor, bound_task]() {
5255     stream_executor->EnqueueOnBackgroundThread(bound_task);
5256   });
5257 }
5258 
BlockHostUntilDone()5259 port::Status Stream::BlockHostUntilDone() {
5260   VLOG_CALL();
5261 
5262   if (!ok()) {
5263     port::Status status = port::Status(
5264         port::error::INTERNAL,
5265         "stream did not block host until done; was already in an error state");
5266     LOG(INFO) << DebugStreamPointers() << " " << status;
5267     return status;
5268   }
5269 
5270   temporary_memory_manager_.DeallocateFinalizedTemporaries();
5271 
5272   port::Status error = parent_->BlockHostUntilDone(this);
5273   CheckError(error.ok());
5274   return error;
5275 }
5276 
DebugStreamPointers() const5277 string Stream::DebugStreamPointers() const {
5278   // Relies on the ToVlogString(const void*) overload above.
5279   return absl::StrCat("[stream=", ToVlogString(this),
5280                       ",impl=", ToVlogString(implementation_.get()), "]");
5281 }
5282 
CheckStatus(port::Status status)5283 void Stream::CheckStatus(port::Status status) {
5284   if (status.ok()) {
5285     return;
5286   }
5287   LOG(ERROR) << status;
5288   mutex_lock lock(mu_);
5289   ok_ = false;
5290 }
5291 
5292 }  // namespace stream_executor
5293