1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/stream_executor/stream.h"
17
18 #include "tensorflow/stream_executor/platform/port.h"
19
20 #include "absl/strings/str_cat.h"
21 #include "third_party/eigen3/Eigen/Core"
22 #include "tensorflow/stream_executor/blas.h"
23 #include "tensorflow/stream_executor/host_or_device_scalar.h"
24 #include "tensorflow/stream_executor/lib/stacktrace.h"
25 #include "tensorflow/stream_executor/platform.h"
26 #include "tensorflow/stream_executor/platform/logging.h"
27 #include "tensorflow/stream_executor/rng.h"
28 #include "tensorflow/stream_executor/stream_executor_internal.h"
29 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
30
31 namespace stream_executor {
32
33 namespace {
34 // Code to turn parameters to functions on stream into strings that
35 // will be VLOG'ed. We need overloads, instead of
36 // e.g. BatchDescriptorToVlogString(), as the code that calls these
37 // functions does not know what the type of the parameter is.
ToVlogString(const dnn::BatchDescriptor & descriptor)38 string ToVlogString(const dnn::BatchDescriptor &descriptor) {
39 return descriptor.ToShortString();
40 }
41
ToVlogString(const dnn::FilterDescriptor & descriptor)42 string ToVlogString(const dnn::FilterDescriptor &descriptor) {
43 return descriptor.ToShortString();
44 }
45
ToVlogString(const dnn::ConvolutionDescriptor & descriptor)46 string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
47 return descriptor.ToShortString();
48 }
49
ToVlogString(const dnn::PoolingDescriptor & descriptor)50 string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
51 return descriptor.ToShortString();
52 }
53
ToVlogString(const dnn::NormalizeDescriptor & descriptor)54 string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
55 return descriptor.ToShortString();
56 }
57
ToVlogString(dnn::ActivationMode mode)58 string ToVlogString(dnn::ActivationMode mode) {
59 return dnn::ActivationModeString(mode);
60 }
61
ToVlogString(const dnn::AlgorithmConfig & algo_config)62 string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
63 return algo_config.ToString();
64 }
65
ToVlogString(dnn::ElementwiseOperation op)66 string ToVlogString(dnn::ElementwiseOperation op) {
67 return dnn::ElementwiseOperationString(op);
68 }
69
ToVlogString(dnn::QuantizedActivationMode mode)70 string ToVlogString(dnn::QuantizedActivationMode mode) {
71 return dnn::QuantizedActivationModeString(mode);
72 }
73
ToVlogString(blas::Transpose t)74 string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }
75
ToVlogString(blas::UpperLower ul)76 string ToVlogString(blas::UpperLower ul) { return blas::UpperLowerString(ul); }
77
ToVlogString(blas::Diagonal d)78 string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
79
ToVlogString(blas::Side s)80 string ToVlogString(blas::Side s) { return blas::SideString(s); }
81
ToVlogString(blas::ComputationType ty)82 string ToVlogString(blas::ComputationType ty) {
83 return blas::ComputationTypeString(ty);
84 }
85
// Renders a raw pointer for VLOG output: the literal "null" for a null
// pointer, otherwise the text that streaming the pointer into an ostream
// produces.
string ToVlogString(const void *ptr) {
  if (ptr != nullptr) {
    // StrCat does not convert pointers to text, so format through a stream.
    std::ostringstream address_text;
    address_text << ptr;
    return address_text.str();
  }
  return "null";
}
96
// Renders a complex number for VLOG output using the standard stream
// insertion format for std::complex.
template <class T>
string ToVlogString(const std::complex<T> &c) {
  // StrCat does not convert std::complex to text, so format through a stream.
  std::ostringstream complex_text;
  complex_text << c;
  return complex_text.str();
}
104
// Renders a std::function for VLOG output. Only emptiness is reported: an
// empty function yields "null", anything else "<non-null function>".
template <class T>
string ToVlogString(const std::function<T> &f) {
  if (f == nullptr) {
    return "null";
  }
  return "<non-null function>";
}
109
ToVlogString(const DeviceMemoryBase & memory)110 string ToVlogString(const DeviceMemoryBase &memory) {
111 return ToVlogString(memory.opaque());
112 }
113
ToVlogString(const DeviceMemoryBase * memory)114 string ToVlogString(const DeviceMemoryBase *memory) {
115 return memory == nullptr ? "null" : ToVlogString(*memory);
116 }
117
ToVlogString(const Eigen::half & h)118 string ToVlogString(const Eigen::half &h) {
119 return absl::StrCat(static_cast<float>(h));
120 }
121
ToVlogString(int i)122 string ToVlogString(int i) { return absl::StrCat(i); }
123
ToVlogString(uint32 i)124 string ToVlogString(uint32 i) { return absl::StrCat(i); }
125
ToVlogString(uint64 i)126 string ToVlogString(uint64 i) { return absl::StrCat(i); }
127
ToVlogString(int64 i)128 string ToVlogString(int64 i) { return absl::StrCat(i); }
129
ToVlogString(float f)130 string ToVlogString(float f) { return absl::StrCat(f); }
131
ToVlogString(double d)132 string ToVlogString(double d) { return absl::StrCat(d); }
133
134 template <typename T>
ToVlogString(const HostOrDeviceScalar<T> & memory_or_constant)135 string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) {
136 if (memory_or_constant.is_pointer()) {
137 return ToVlogString(memory_or_constant.pointer());
138 }
139 return ToVlogString(memory_or_constant.value());
140 }
141
142 template <class T>
ToVlogString(port::ArraySlice<T> elements)143 string ToVlogString(port::ArraySlice<T> elements) {
144 string str = absl::StrCat(
145 ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
146 elements.size(), "]{");
147 const char *separator = "";
148 size_t max_to_show = std::numeric_limits<size_t>::max();
149 if (!VLOG_IS_ON(2)) {
150 max_to_show = 5;
151 } else if (!VLOG_IS_ON(3)) {
152 max_to_show = 20;
153 } else if (!VLOG_IS_ON(11)) {
154 max_to_show = 1000;
155 }
156 for (size_t i = 0; i < elements.size(); ++i) {
157 if (i == max_to_show) {
158 str += ", ...";
159 break;
160 }
161 absl::StrAppend(&str, separator, ToVlogString(elements[i]));
162 separator = ", ";
163 }
164 str += "}";
165 return str;
166 }
167
168 template <class T>
ToVlogString(port::MutableArraySlice<T> elements)169 string ToVlogString(port::MutableArraySlice<T> elements) {
170 return ToVlogString(port::ArraySlice<T>(elements));
171 }
172
ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout)173 string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
174 switch (depth_to_space_layout) {
175 case dnn::DepthToSpaceLayout::DepthHeightWidth:
176 return "DepthToSpaceLayout::DepthHeightWidth";
177 }
178 return "unknown DepthToSpaceLayout";
179 }
180
ToVlogString(dnn::DataType data_type)181 string ToVlogString(dnn::DataType data_type) {
182 switch (data_type) {
183 case dnn::DataType::kFloat:
184 return "dnn::DataType::kFloat";
185 case dnn::DataType::kDouble:
186 return "dnn::DataType::kDouble";
187 case dnn::DataType::kHalf:
188 return "dnn::DataType::kHalf";
189 case dnn::DataType::kInt8:
190 return "dnn::DataType::kInt8";
191 case dnn::DataType::kInt32:
192 return "dnn::DataType::kInt32";
193 default:
194 return "unknown DataType";
195 }
196 }
197
198 // Used together with PARAM to VLOG calls made to the stream. Intended
199 // to be used like this:
200 //
201 // VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)});
202 //
203 // where a and b are the parameters to MyFunction.
204 //
205 // See VLOG_CALL for a short-hand for this. This way of doing it saves
206 // a tremendous amount of boilerplate code given how many functions
207 // there are on Stream and how many parameters they each have.
CallStr(const char * function_name,Stream * stream,std::vector<std::pair<const char *,string>> params)208 string CallStr(const char *function_name, Stream *stream,
209 std::vector<std::pair<const char *, string>> params) {
210 // Do not call this function unless VLOG is on since just
211 // constructing all the strings in params is expensive.
212 CHECK(VLOG_IS_ON(1));
213
214 string str = absl::StrCat(stream->DebugStreamPointers(),
215 " Called Stream::", function_name, "(");
216 const char *separator = "";
217 for (const auto ¶m : params) {
218 absl::StrAppend(&str, separator, param.first, "=", param.second);
219 separator = ", ";
220 }
221 absl::StrAppend(&str, ")");
222 if (VLOG_IS_ON(10)) {
223 absl::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
224 }
225 return str;
226 }
227
// Expands to a {"name", formatted-value} initializer for one parameter, so
// that a call site logging through VLOG and CallStr does not have to type
// every parameter twice.
#define PARAM(parameter) {#parameter, ToVlogString(parameter)}
232
// Logs the enclosing Stream member function and its parameters at VLOG
// level 1, without having to type out the function name. Intended to be used
// like this:
//
//   VLOG_CALL(PARAM(a), PARAM(b))
//
// This saves a tremendous amount of boilerplate compared to the alternative:
//
//   VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a)
//           << ", b=" << ToVlogString(b);
//
// Note that most of the parameter names are not short and that most of the
// functions take many more than 2 parameters. The expansion uses `this` and
// __func__, so it is only usable inside a member function.
#define VLOG_CALL(...) \
  VLOG(1) << CallStr(__func__, this, {__VA_ARGS__})
246
247 } // namespace
248
Stream(StreamExecutor * parent)249 Stream::Stream(StreamExecutor *parent)
250 : parent_(parent),
251 implementation_(parent->implementation()->GetStreamImplementation()),
252 allocated_(false),
253 ok_(false),
254 temporary_memory_manager_(this) {
255 VLOG_CALL(PARAM(parent));
256 }
257
Stream(StreamExecutor * parent,internal::StreamInterface * implementation)258 Stream::Stream(StreamExecutor *parent,
259 internal::StreamInterface *implementation)
260 : parent_(parent),
261 implementation_(implementation),
262 allocated_(false),
263 ok_(false),
264 temporary_memory_manager_(this) {
265 VLOG_CALL(PARAM(parent), PARAM(implementation));
266 }
267
~Stream()268 Stream::~Stream() {
269 VLOG_CALL();
270
271 // Ensure the stream is completed.
272 auto status = BlockHostUntilDone();
273 if (!status.ok()) {
274 LOG(WARNING) << "Error blocking host until done in stream destructor: "
275 << status;
276 }
277 temporary_memory_manager_.ForceDeallocateAll();
278
279 if (allocated_) {
280 parent_->DeallocateStream(this);
281 }
282 }
283
RefreshStatus()284 port::Status Stream::RefreshStatus() {
285 port::Status status = parent_->GetStatus(this);
286 CheckStatus(status);
287 return status;
288 }
289
Init()290 Stream &Stream::Init() {
291 VLOG_CALL();
292
293 mutex_lock lock(mu_);
294 CHECK_EQ(false, allocated_)
295 << "stream appears to already have been initialized";
296 CHECK(!ok_) << "stream should be in !ok() state pre-initialization";
297
298 if (parent_->AllocateStream(this)) {
299 // Successful initialization!
300 allocated_ = true;
301 ok_ = true;
302 } else {
303 LOG(ERROR) << "failed to allocate stream during initialization";
304 }
305
306 return *this;
307 }
308
InitTimer(Timer * timer)309 Stream &Stream::InitTimer(Timer *timer) {
310 VLOG_CALL(PARAM(timer));
311
312 if (ok()) {
313 CheckError(parent_->AllocateTimer(timer));
314 } else {
315 LOG(INFO) << "did not allocate timer: " << timer;
316 }
317 return *this;
318 }
319
InitWithTimer(Timer * timer)320 Stream &Stream::InitWithTimer(Timer *timer) {
321 VLOG_CALL(PARAM(timer));
322
323 return Init().InitTimer(timer);
324 }
325
ThenRecordEvent(Event * event)326 Stream &Stream::ThenRecordEvent(Event *event) {
327 VLOG_CALL(PARAM(event));
328
329 port::Status status = parent_->RecordEvent(this, event);
330 if (!status.ok()) {
331 LOG(ERROR) << "Error recording event in stream: " << status.error_message()
332 << "; not marking stream as bad, as the Event object may be "
333 << "at fault. Monitor for further errors.";
334 }
335
336 return *this;
337 }
338
ThenBatchNormalizationForward(const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<float> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,std::function<const DeviceMemory<float> & ()> var_to_inv_var,std::function<void ()> inv_var_to_var)339 Stream &Stream::ThenBatchNormalizationForward(
340 const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
341 const DeviceMemory<float> &offset,
342 const DeviceMemory<float> &estimated_mean,
343 const DeviceMemory<float> &estimated_variance,
344 const dnn::BatchDescriptor &x_desc,
345 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
346 DeviceMemory<float> *y, DeviceMemory<float> *batch_mean,
347 DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
348 DeviceMemory<float> *saved_inv_var, bool is_training,
349 std::function<const DeviceMemory<float> &()> var_to_inv_var,
350 std::function<void()> inv_var_to_var) {
351 VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
352 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
353 if (ok()) {
354 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
355 CheckError(dnn->DoBatchNormalizationForward(
356 this, x, scale, offset, estimated_mean, estimated_variance, x_desc,
357 scale_offset_desc, epsilon, y, batch_mean, batch_var, saved_mean,
358 saved_inv_var, is_training, std::move(var_to_inv_var),
359 std::move(inv_var_to_var)));
360 } else {
361 SetErrorAndLogNoDnnSupport();
362 }
363 }
364 return *this;
365 }
366
ThenBatchNormalizationBackward(const DeviceMemory<float> & y_backprop,const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<float> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop)367 Stream &Stream::ThenBatchNormalizationBackward(
368 const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
369 const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
370 const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
371 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
372 DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
373 DeviceMemory<float> *offset_backprop) {
374 VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
375 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
376 PARAM(scale_backprop), PARAM(offset_backprop));
377 if (ok()) {
378 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
379 CheckError(dnn->DoBatchNormalizationBackward(
380 this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
381 epsilon, x_backprop, scale_backprop, offset_backprop));
382 } else {
383 SetErrorAndLogNoDnnSupport();
384 }
385 }
386 return *this;
387 }
388
ThenBatchNormalizationForward(const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<Eigen::half> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,std::function<const DeviceMemory<float> & ()> var_to_inv_var,std::function<void ()> inv_var_to_var)389 Stream &Stream::ThenBatchNormalizationForward(
390 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
391 const DeviceMemory<float> &offset,
392 const DeviceMemory<float> &estimated_mean,
393 const DeviceMemory<float> &estimated_variance,
394 const dnn::BatchDescriptor &x_desc,
395 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
396 DeviceMemory<Eigen::half> *y, DeviceMemory<float> *batch_mean,
397 DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
398 DeviceMemory<float> *saved_inv_var, bool is_training,
399 std::function<const DeviceMemory<float> &()> var_to_inv_var,
400 std::function<void()> inv_var_to_var) {
401 VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
402 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
403 if (ok()) {
404 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
405 CheckError(dnn->DoBatchNormalizationForward(
406 this, x, scale, offset, estimated_mean, estimated_variance, x_desc,
407 scale_offset_desc, epsilon, y, batch_mean, batch_var, saved_mean,
408 saved_inv_var, is_training, std::move(var_to_inv_var),
409 std::move(inv_var_to_var)));
410 } else {
411 SetErrorAndLogNoDnnSupport();
412 }
413 }
414 return *this;
415 }
416
ThenBatchNormalizationBackward(const DeviceMemory<Eigen::half> & y_backprop,const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<Eigen::half> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop)417 Stream &Stream::ThenBatchNormalizationBackward(
418 const DeviceMemory<Eigen::half> &y_backprop,
419 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
420 const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
421 const dnn::BatchDescriptor &x_desc,
422 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
423 DeviceMemory<Eigen::half> *x_backprop, DeviceMemory<float> *scale_backprop,
424 DeviceMemory<float> *offset_backprop) {
425 VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
426 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
427 PARAM(scale_backprop), PARAM(offset_backprop));
428 if (ok()) {
429 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
430 CheckError(dnn->DoBatchNormalizationBackward(
431 this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
432 epsilon, x_backprop, scale_backprop, offset_backprop));
433 } else {
434 SetErrorAndLogNoDnnSupport();
435 }
436 }
437 return *this;
438 }
439
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<double> & conv_input_data,double conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<double> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<double> & side_input_data,double side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<double> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<double> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)440 Stream &Stream::ThenFusedConvolveWithAlgorithm(
441 const dnn::BatchDescriptor &conv_input_descriptor,
442 const DeviceMemory<double> &conv_input_data, double conv_input_scale,
443 const dnn::FilterDescriptor &filter_descriptor,
444 const DeviceMemory<double> &filter_data,
445 const dnn::ConvolutionDescriptor &convolution_descriptor,
446 const DeviceMemory<double> &side_input_data, double side_input_scale,
447 const dnn::BatchDescriptor &bias_descriptor,
448 const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
449 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
450 ScratchAllocator *scratch_allocator,
451 const dnn::AlgorithmConfig &algorithm_config,
452 dnn::ProfileResult *output_profile_result) {
453 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
454 PARAM(conv_input_scale), PARAM(filter_descriptor),
455 PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
456 PARAM(side_input_data), PARAM(side_input_scale),
457 PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
458 PARAM(algorithm_config));
459
460 if (ok()) {
461 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
462 auto status = dnn->DoFusedConvolve(
463 this, conv_input_descriptor, conv_input_data, conv_input_scale,
464 filter_descriptor, filter_data, convolution_descriptor,
465 side_input_data, side_input_scale, bias_descriptor, biases,
466 activation_mode, output_descriptor, output, scratch_allocator,
467 algorithm_config, output_profile_result);
468 if (!status && !output_profile_result) {
469 SetError();
470 }
471 } else {
472 SetErrorAndLogNoDnnSupport();
473 }
474 }
475 return *this;
476 }
477
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<float> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<float> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)478 Stream &Stream::ThenFusedConvolveWithAlgorithm(
479 const dnn::BatchDescriptor &conv_input_descriptor,
480 const DeviceMemory<float> &conv_input_data, float conv_input_scale,
481 const dnn::FilterDescriptor &filter_descriptor,
482 const DeviceMemory<float> &filter_data,
483 const dnn::ConvolutionDescriptor &convolution_descriptor,
484 const DeviceMemory<float> &side_input_data, float side_input_scale,
485 const dnn::BatchDescriptor &bias_descriptor,
486 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
487 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
488 ScratchAllocator *scratch_allocator,
489 const dnn::AlgorithmConfig &algorithm_config,
490 dnn::ProfileResult *output_profile_result) {
491 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
492 PARAM(conv_input_scale), PARAM(filter_descriptor),
493 PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
494 PARAM(side_input_data), PARAM(side_input_scale),
495 PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
496 PARAM(algorithm_config));
497
498 if (ok()) {
499 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
500 auto status = dnn->DoFusedConvolve(
501 this, conv_input_descriptor, conv_input_data, conv_input_scale,
502 filter_descriptor, filter_data, convolution_descriptor,
503 side_input_data, side_input_scale, bias_descriptor, biases,
504 activation_mode, output_descriptor, output, scratch_allocator,
505 algorithm_config, output_profile_result);
506 if (!status && !output_profile_result) {
507 SetError();
508 }
509 } else {
510 SetErrorAndLogNoDnnSupport();
511 }
512 }
513 return *this;
514 }
515
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<Eigen::half> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<Eigen::half> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<Eigen::half> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)516 Stream &Stream::ThenFusedConvolveWithAlgorithm(
517 const dnn::BatchDescriptor &conv_input_descriptor,
518 const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
519 const dnn::FilterDescriptor &filter_descriptor,
520 const DeviceMemory<Eigen::half> &filter_data,
521 const dnn::ConvolutionDescriptor &convolution_descriptor,
522 const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
523 const dnn::BatchDescriptor &bias_descriptor,
524 const DeviceMemory<Eigen::half> &biases,
525 dnn::ActivationMode activation_mode,
526 const dnn::BatchDescriptor &output_descriptor,
527 DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
528 const dnn::AlgorithmConfig &algorithm_config,
529 dnn::ProfileResult *output_profile_result) {
530 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
531 PARAM(conv_input_scale), PARAM(filter_descriptor),
532 PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
533 PARAM(side_input_data), PARAM(side_input_scale),
534 PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
535 PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
536
537 if (ok()) {
538 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
539 auto status = dnn->DoFusedConvolve(
540 this, conv_input_descriptor, conv_input_data, conv_input_scale,
541 filter_descriptor, filter_data, convolution_descriptor,
542 side_input_data, side_input_scale, bias_descriptor, biases,
543 activation_mode, output_descriptor, output, scratch_allocator,
544 algorithm_config, output_profile_result);
545 if (!status && !output_profile_result) {
546 SetError();
547 }
548 } else {
549 SetErrorAndLogNoDnnSupport();
550 }
551 }
552 return *this;
553 }
554
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<int8> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<int8> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<int8> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)555 Stream &Stream::ThenFusedConvolveWithAlgorithm(
556 const dnn::BatchDescriptor &conv_input_descriptor,
557 const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
558 const dnn::FilterDescriptor &filter_descriptor,
559 const DeviceMemory<int8> &filter_data,
560 const dnn::ConvolutionDescriptor &convolution_descriptor,
561 const DeviceMemory<int8> &side_input_data, float side_input_scale,
562 const dnn::BatchDescriptor &bias_descriptor,
563 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
564 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
565 ScratchAllocator *scratch_allocator,
566 const dnn::AlgorithmConfig &algorithm_config,
567 dnn::ProfileResult *output_profile_result) {
568 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
569 PARAM(conv_input_scale), PARAM(filter_descriptor),
570 PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
571 PARAM(side_input_data), PARAM(side_input_scale),
572 PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
573 PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
574
575 if (ok()) {
576 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
577 auto status = dnn->DoFusedConvolve(
578 this, conv_input_descriptor, conv_input_data, conv_input_scale,
579 filter_descriptor, filter_data, convolution_descriptor,
580 side_input_data, side_input_scale, bias_descriptor, biases,
581 activation_mode, output_descriptor, output, scratch_allocator,
582 algorithm_config, output_profile_result);
583 if (!status && !output_profile_result) {
584 SetError();
585 }
586 } else {
587 SetErrorAndLogNoDnnSupport();
588 }
589 }
590 return *this;
591 }
592
ThenConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<double> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<double> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<double> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)593 Stream &Stream::ThenConvolveWithAlgorithm(
594 const dnn::BatchDescriptor &input_descriptor,
595 const DeviceMemory<double> &input_data,
596 const dnn::FilterDescriptor &filter_descriptor,
597 const DeviceMemory<double> &filter_data,
598 const dnn::ConvolutionDescriptor &convolution_descriptor,
599 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
600 ScratchAllocator *scratch_allocator,
601 const dnn::AlgorithmConfig &algorithm_config,
602 dnn::ProfileResult *output_profile_result) {
603 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
604 PARAM(filter_descriptor), PARAM(filter_data),
605 PARAM(convolution_descriptor), PARAM(output_descriptor),
606 PARAM(output), PARAM(algorithm_config));
607
608 if (ok()) {
609 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
610 DeviceMemory<uint8> scratch_memory;
611 dnn::AlgorithmDesc algorithm_desc;
612 auto status =
613 dnn->PrepareForConvolution(
614 dnn::ConvolutionKind::FORWARD, this, input_descriptor,
615 input_data, filter_descriptor, filter_data, output_descriptor,
616 *output, convolution_descriptor, algorithm_config,
617 scratch_allocator, &algorithm_desc, &scratch_memory)
618 .ok();
619 if (status) {
620 status = dnn->DoConvolve(
621 this, input_descriptor, input_data, filter_descriptor, filter_data,
622 convolution_descriptor, output_descriptor, output, algorithm_desc,
623 &scratch_memory, output_profile_result);
624 }
625 if (!status && !output_profile_result) {
626 SetError();
627 }
628 } else {
629 SetErrorAndLogNoDnnSupport();
630 }
631 }
632 return *this;
633 }
634
ThenConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)635 Stream &Stream::ThenConvolveWithAlgorithm(
636 const dnn::BatchDescriptor &input_descriptor,
637 const DeviceMemory<float> &input_data,
638 const dnn::FilterDescriptor &filter_descriptor,
639 const DeviceMemory<float> &filter_data,
640 const dnn::ConvolutionDescriptor &convolution_descriptor,
641 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
642 ScratchAllocator *scratch_allocator,
643 const dnn::AlgorithmConfig &algorithm_config,
644 dnn::ProfileResult *output_profile_result) {
645 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
646 PARAM(filter_descriptor), PARAM(filter_data),
647 PARAM(convolution_descriptor), PARAM(output_descriptor),
648 PARAM(output), PARAM(algorithm_config));
649
650 if (ok()) {
651 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
652 DeviceMemory<uint8> scratch_memory;
653 dnn::AlgorithmDesc algorithm_desc;
654 auto status =
655 dnn->PrepareForConvolution(
656 dnn::ConvolutionKind::FORWARD, this, input_descriptor,
657 input_data, filter_descriptor, filter_data, output_descriptor,
658 *output, convolution_descriptor, algorithm_config,
659 scratch_allocator, &algorithm_desc, &scratch_memory)
660 .ok();
661 if (status) {
662 status = dnn->DoConvolve(
663 this, input_descriptor, input_data, filter_descriptor, filter_data,
664 convolution_descriptor, output_descriptor, output, algorithm_desc,
665 &scratch_memory, output_profile_result);
666 }
667 if (!status && !output_profile_result) {
668 SetError();
669 }
670 } else {
671 SetErrorAndLogNoDnnSupport();
672 }
673 }
674 return *this;
675 }
676
ThenConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)677 Stream &Stream::ThenConvolveWithAlgorithm(
678 const dnn::BatchDescriptor &input_descriptor,
679 const DeviceMemory<Eigen::half> &input_data,
680 const dnn::FilterDescriptor &filter_descriptor,
681 const DeviceMemory<Eigen::half> &filter_data,
682 const dnn::ConvolutionDescriptor &convolution_descriptor,
683 const dnn::BatchDescriptor &output_descriptor,
684 DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
685 const dnn::AlgorithmConfig &algorithm_config,
686 dnn::ProfileResult *output_profile_result) {
687 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
688 PARAM(filter_descriptor), PARAM(filter_data),
689 PARAM(convolution_descriptor), PARAM(output_descriptor),
690 PARAM(output), PARAM(algorithm_config));
691
692 if (ok()) {
693 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
694 DeviceMemory<uint8> scratch_memory;
695 dnn::AlgorithmDesc algorithm_desc;
696 auto status =
697 dnn->PrepareForConvolution(
698 dnn::ConvolutionKind::FORWARD, this, input_descriptor,
699 input_data, filter_descriptor, filter_data, output_descriptor,
700 *output, convolution_descriptor, algorithm_config,
701 scratch_allocator, &algorithm_desc, &scratch_memory)
702 .ok();
703 if (status) {
704 status = dnn->DoConvolve(
705 this, input_descriptor, input_data, filter_descriptor, filter_data,
706 convolution_descriptor, output_descriptor, output, algorithm_desc,
707 &scratch_memory, output_profile_result);
708 }
709 if (!status && !output_profile_result) {
710 SetError();
711 }
712 } else {
713 SetErrorAndLogNoDnnSupport();
714 }
715 }
716 return *this;
717 }
718
ThenConvolve(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)719 Stream &Stream::ThenConvolve(
720 const dnn::BatchDescriptor &input_descriptor,
721 const DeviceMemory<float> &input_data,
722 const dnn::FilterDescriptor &filter_descriptor,
723 const DeviceMemory<float> &filter_data,
724 const dnn::ConvolutionDescriptor &convolution_descriptor,
725 const dnn::BatchDescriptor &output_descriptor,
726 DeviceMemory<float> *output) {
727 return ThenConvolveWithAlgorithm(
728 input_descriptor, input_data, filter_descriptor, filter_data,
729 convolution_descriptor, output_descriptor, output,
730 /*scratch_allocator=*/nullptr, dnn::AlgorithmConfig(),
731 /*output_profile_result=*/nullptr);
732 }
733
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)734 Stream &Stream::ThenConvolveQuantized(
735 const dnn::BatchDescriptor &input_descriptor,
736 const DeviceMemory<float> &input_data,
737 const dnn::FilterDescriptor &filter_descriptor,
738 const DeviceMemory<int8> &filter_coefficients,
739 const DeviceMemory<float> &coefficient_scales,
740 const dnn::ConvolutionDescriptor &convolution_descriptor,
741 const dnn::BatchDescriptor &output_descriptor,
742 DeviceMemory<float> *output) {
743 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
744 PARAM(filter_descriptor), PARAM(filter_coefficients),
745 PARAM(coefficient_scales), PARAM(convolution_descriptor),
746 PARAM(output_descriptor), PARAM(output));
747
748 if (ok()) {
749 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
750 CheckError(dnn->DoConvolveQuantized(
751 this, input_descriptor, input_data, filter_descriptor,
752 filter_coefficients, coefficient_scales, convolution_descriptor,
753 output_descriptor, output));
754 } else {
755 SetError();
756 LOG(WARNING)
757 << "attempting to perform DNN operation using StreamExecutor "
758 "without DNN support";
759 }
760 }
761 return *this;
762 }
763
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int16> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)764 Stream &Stream::ThenConvolveQuantized(
765 const dnn::BatchDescriptor &input_descriptor,
766 const DeviceMemory<float> &input_data,
767 const dnn::FilterDescriptor &filter_descriptor,
768 const DeviceMemory<int16> &filter_coefficients,
769 const DeviceMemory<float> &coefficient_scales,
770 const dnn::ConvolutionDescriptor &convolution_descriptor,
771 const dnn::BatchDescriptor &output_descriptor,
772 DeviceMemory<float> *output) {
773 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
774 PARAM(filter_descriptor), PARAM(filter_coefficients),
775 PARAM(coefficient_scales), PARAM(convolution_descriptor),
776 PARAM(output_descriptor), PARAM(output));
777
778 if (ok()) {
779 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
780 CheckError(dnn->DoConvolveQuantized(
781 this, input_descriptor, input_data, filter_descriptor,
782 filter_coefficients, coefficient_scales, convolution_descriptor,
783 output_descriptor, output));
784 } else {
785 SetError();
786 LOG(WARNING)
787 << "attempting to perform DNN operation using StreamExecutor "
788 "without DNN support";
789 }
790 }
791 return *this;
792 }
793
ThenSeparableConvolve(const dnn::BatchDescriptor & batch_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,int depth_multiplier,const DeviceMemory<float> & first_weights,const DeviceMemory<float> & second_weights,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)794 Stream &Stream::ThenSeparableConvolve(
795 const dnn::BatchDescriptor &batch_descriptor,
796 const DeviceMemory<float> &input_data,
797 const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
798 const DeviceMemory<float> &first_weights,
799 const DeviceMemory<float> &second_weights,
800 const dnn::ConvolutionDescriptor &convolution_descriptor,
801 const dnn::BatchDescriptor &output_descriptor,
802 DeviceMemory<float> *output) {
803 VLOG_CALL(
804 PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor),
805 PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights),
806 PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output));
807
808 if (ok()) {
809 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
810 CheckError(dnn->DoSeparableConvolve(
811 this, batch_descriptor, input_data, filter_descriptor,
812 depth_multiplier, first_weights, second_weights,
813 convolution_descriptor, output_descriptor, output));
814 } else {
815 SetErrorAndLogNoDnnSupport();
816 }
817 }
818 return *this;
819 }
820
ThenConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<double> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<double> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<double> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)821 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
822 const dnn::FilterDescriptor &filter_descriptor,
823 const DeviceMemory<double> &filter_data,
824 const dnn::BatchDescriptor &output_descriptor,
825 DeviceMemory<double> backward_output_data,
826 const dnn::ConvolutionDescriptor &convolution_descriptor,
827 const dnn::BatchDescriptor &input_descriptor,
828 DeviceMemory<double> *backward_input_data,
829 ScratchAllocator *scratch_allocator,
830 const dnn::AlgorithmConfig &algorithm_config,
831 dnn::ProfileResult *output_profile_result) {
832 VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
833 PARAM(output_descriptor), PARAM(backward_output_data),
834 PARAM(convolution_descriptor), PARAM(input_descriptor),
835 PARAM(backward_input_data));
836
837 if (ok()) {
838 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
839 DeviceMemory<uint8> scratch_memory;
840 dnn::AlgorithmDesc algorithm_desc;
841 auto status =
842 dnn->PrepareForConvolution(
843 dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
844 *backward_input_data, filter_descriptor, filter_data,
845 output_descriptor, backward_output_data,
846 convolution_descriptor, algorithm_config, scratch_allocator,
847 &algorithm_desc, &scratch_memory)
848 .ok();
849 if (status) {
850 status = dnn->DoConvolveBackwardData(
851 this, filter_descriptor, filter_data, output_descriptor,
852 backward_output_data, convolution_descriptor, input_descriptor,
853 backward_input_data, algorithm_desc, &scratch_memory,
854 output_profile_result);
855 }
856 if (!status && !output_profile_result) {
857 SetError();
858 }
859 } else {
860 SetErrorAndLogNoDnnSupport();
861 }
862 }
863 return *this;
864 }
865
ThenConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<float> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)866 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
867 const dnn::FilterDescriptor &filter_descriptor,
868 const DeviceMemory<float> &filter_data,
869 const dnn::BatchDescriptor &output_descriptor,
870 DeviceMemory<float> backward_output_data,
871 const dnn::ConvolutionDescriptor &convolution_descriptor,
872 const dnn::BatchDescriptor &input_descriptor,
873 DeviceMemory<float> *backward_input_data,
874 ScratchAllocator *scratch_allocator,
875 const dnn::AlgorithmConfig &algorithm_config,
876 dnn::ProfileResult *output_profile_result) {
877 VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
878 PARAM(output_descriptor), PARAM(backward_output_data),
879 PARAM(convolution_descriptor), PARAM(input_descriptor),
880 PARAM(backward_input_data));
881
882 if (ok()) {
883 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
884 DeviceMemory<uint8> scratch_memory;
885 dnn::AlgorithmDesc algorithm_desc;
886 auto status =
887 dnn->PrepareForConvolution(
888 dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
889 *backward_input_data, filter_descriptor, filter_data,
890 output_descriptor, backward_output_data,
891 convolution_descriptor, algorithm_config, scratch_allocator,
892 &algorithm_desc, &scratch_memory)
893 .ok();
894 if (status) {
895 status = dnn->DoConvolveBackwardData(
896 this, filter_descriptor, filter_data, output_descriptor,
897 backward_output_data, convolution_descriptor, input_descriptor,
898 backward_input_data, algorithm_desc, &scratch_memory,
899 output_profile_result);
900 }
901 if (!status && !output_profile_result) {
902 SetError();
903 }
904 } else {
905 SetErrorAndLogNoDnnSupport();
906 }
907 }
908 return *this;
909 }
910
ThenConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<Eigen::half> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)911 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
912 const dnn::FilterDescriptor &filter_descriptor,
913 const DeviceMemory<Eigen::half> &filter_data,
914 const dnn::BatchDescriptor &output_descriptor,
915 DeviceMemory<Eigen::half> backward_output_data,
916 const dnn::ConvolutionDescriptor &convolution_descriptor,
917 const dnn::BatchDescriptor &input_descriptor,
918 DeviceMemory<Eigen::half> *backward_input_data,
919 ScratchAllocator *scratch_allocator,
920 const dnn::AlgorithmConfig &algorithm_config,
921 dnn::ProfileResult *output_profile_result) {
922 VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
923 PARAM(output_descriptor), PARAM(backward_output_data),
924 PARAM(convolution_descriptor), PARAM(input_descriptor),
925 PARAM(backward_input_data));
926
927 if (ok()) {
928 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
929 DeviceMemory<uint8> scratch_memory;
930 dnn::AlgorithmDesc algorithm_desc;
931 auto status =
932 dnn->PrepareForConvolution(
933 dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
934 *backward_input_data, filter_descriptor, filter_data,
935 output_descriptor, backward_output_data,
936 convolution_descriptor, algorithm_config, scratch_allocator,
937 &algorithm_desc, &scratch_memory)
938 .ok();
939 if (status) {
940 status = dnn->DoConvolveBackwardData(
941 this, filter_descriptor, filter_data, output_descriptor,
942 backward_output_data, convolution_descriptor, input_descriptor,
943 backward_input_data, algorithm_desc, &scratch_memory,
944 output_profile_result);
945 }
946 if (!status && !output_profile_result) {
947 SetError();
948 }
949 } else {
950 SetErrorAndLogNoDnnSupport();
951 }
952 }
953 return *this;
954 }
955
ThenConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<double> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<double> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)956 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
957 const dnn::BatchDescriptor &input_descriptor,
958 const DeviceMemory<double> &input_data,
959 const dnn::BatchDescriptor &output_descriptor,
960 DeviceMemory<double> backward_output_data,
961 const dnn::ConvolutionDescriptor &convolution_descriptor,
962 const dnn::FilterDescriptor &filter_descriptor,
963 DeviceMemory<double> *backward_filter_data,
964 ScratchAllocator *scratch_allocator,
965 const dnn::AlgorithmConfig &algorithm_config,
966 dnn::ProfileResult *output_profile_result) {
967 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
968 PARAM(output_descriptor), PARAM(backward_output_data),
969 PARAM(convolution_descriptor), PARAM(filter_descriptor),
970 PARAM(backward_filter_data));
971
972 if (ok()) {
973 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
974 DeviceMemory<uint8> scratch_memory;
975 dnn::AlgorithmDesc algorithm_desc;
976 auto status =
977 dnn->PrepareForConvolution(
978 dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
979 input_data, filter_descriptor, *backward_filter_data,
980 output_descriptor, backward_output_data,
981 convolution_descriptor, algorithm_config, scratch_allocator,
982 &algorithm_desc, &scratch_memory)
983 .ok();
984 if (status) {
985 status = dnn->DoConvolveBackwardFilter(
986 this, input_descriptor, input_data, output_descriptor,
987 backward_output_data, convolution_descriptor, filter_descriptor,
988 backward_filter_data, algorithm_desc, &scratch_memory,
989 output_profile_result);
990 }
991 if (!status && !output_profile_result) {
992 SetError();
993 }
994 } else {
995 SetErrorAndLogNoDnnSupport();
996 }
997 }
998 return *this;
999 }
1000
ThenConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<float> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1001 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
1002 const dnn::BatchDescriptor &input_descriptor,
1003 const DeviceMemory<float> &input_data,
1004 const dnn::BatchDescriptor &output_descriptor,
1005 DeviceMemory<float> backward_output_data,
1006 const dnn::ConvolutionDescriptor &convolution_descriptor,
1007 const dnn::FilterDescriptor &filter_descriptor,
1008 DeviceMemory<float> *backward_filter_data,
1009 ScratchAllocator *scratch_allocator,
1010 const dnn::AlgorithmConfig &algorithm_config,
1011 dnn::ProfileResult *output_profile_result) {
1012 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
1013 PARAM(output_descriptor), PARAM(backward_output_data),
1014 PARAM(convolution_descriptor), PARAM(filter_descriptor),
1015 PARAM(backward_filter_data));
1016
1017 if (ok()) {
1018 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1019 DeviceMemory<uint8> scratch_memory;
1020 dnn::AlgorithmDesc algorithm_desc;
1021 auto status =
1022 dnn->PrepareForConvolution(
1023 dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
1024 input_data, filter_descriptor, *backward_filter_data,
1025 output_descriptor, backward_output_data,
1026 convolution_descriptor, algorithm_config, scratch_allocator,
1027 &algorithm_desc, &scratch_memory)
1028 .ok();
1029 if (status) {
1030 status = dnn->DoConvolveBackwardFilter(
1031 this, input_descriptor, input_data, output_descriptor,
1032 backward_output_data, convolution_descriptor, filter_descriptor,
1033 backward_filter_data, algorithm_desc, &scratch_memory,
1034 output_profile_result);
1035 }
1036 if (!status && !output_profile_result) {
1037 SetError();
1038 }
1039 } else {
1040 SetErrorAndLogNoDnnSupport();
1041 }
1042 }
1043 return *this;
1044 }
1045
ThenConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<Eigen::half> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1046 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
1047 const dnn::BatchDescriptor &input_descriptor,
1048 const DeviceMemory<Eigen::half> &input_data,
1049 const dnn::BatchDescriptor &output_descriptor,
1050 DeviceMemory<Eigen::half> backward_output_data,
1051 const dnn::ConvolutionDescriptor &convolution_descriptor,
1052 const dnn::FilterDescriptor &filter_descriptor,
1053 DeviceMemory<Eigen::half> *backward_filter_data,
1054 ScratchAllocator *scratch_allocator,
1055 const dnn::AlgorithmConfig &algorithm_config,
1056 dnn::ProfileResult *output_profile_result) {
1057 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
1058 PARAM(output_descriptor), PARAM(backward_output_data),
1059 PARAM(convolution_descriptor), PARAM(filter_descriptor),
1060 PARAM(backward_filter_data));
1061
1062 if (ok()) {
1063 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1064 DeviceMemory<uint8> scratch_memory;
1065 dnn::AlgorithmDesc algorithm_desc;
1066 auto status =
1067 dnn->PrepareForConvolution(
1068 dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
1069 input_data, filter_descriptor, *backward_filter_data,
1070 output_descriptor, backward_output_data,
1071 convolution_descriptor, algorithm_config, scratch_allocator,
1072 &algorithm_desc, &scratch_memory)
1073 .ok();
1074 if (status) {
1075 status = dnn->DoConvolveBackwardFilter(
1076 this, input_descriptor, input_data, output_descriptor,
1077 backward_output_data, convolution_descriptor, filter_descriptor,
1078 backward_filter_data, algorithm_desc, &scratch_memory,
1079 output_profile_result);
1080 }
1081 if (!status && !output_profile_result) {
1082 SetError();
1083 }
1084 } else {
1085 SetErrorAndLogNoDnnSupport();
1086 }
1087 }
1088 return *this;
1089 }
1090
1091 template <typename T>
ThenConvolveBackwardBiasImpl(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<T> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<T> * backward_bias_data)1092 Stream &Stream::ThenConvolveBackwardBiasImpl(
1093 const dnn::BatchDescriptor &input_descriptor,
1094 const DeviceMemory<T> &input_data,
1095 const dnn::BatchDescriptor &bias_descriptor,
1096 DeviceMemory<T> *backward_bias_data) {
1097 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(bias_descriptor),
1098 PARAM(backward_bias_data));
1099
1100 if (ok()) {
1101 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1102 CheckError(dnn->DoConvolveBackwardBias(this, input_descriptor, input_data,
1103 bias_descriptor,
1104 backward_bias_data));
1105 } else {
1106 SetErrorAndLogNoDnnSupport();
1107 }
1108 }
1109 return *this;
1110 }
1111
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<double> * backward_bias_data)1112 Stream &Stream::ThenConvolveBackwardBias(
1113 const dnn::BatchDescriptor &input_descriptor,
1114 const DeviceMemory<double> &input_data,
1115 const dnn::BatchDescriptor &bias_descriptor,
1116 DeviceMemory<double> *backward_bias_data) {
1117 return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
1118 bias_descriptor, backward_bias_data);
1119 }
1120
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<float> * backward_bias_data)1121 Stream &Stream::ThenConvolveBackwardBias(
1122 const dnn::BatchDescriptor &input_descriptor,
1123 const DeviceMemory<float> &input_data,
1124 const dnn::BatchDescriptor &bias_descriptor,
1125 DeviceMemory<float> *backward_bias_data) {
1126 return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
1127 bias_descriptor, backward_bias_data);
1128 }
1129
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<Eigen::half> * backward_bias_data)1130 Stream &Stream::ThenConvolveBackwardBias(
1131 const dnn::BatchDescriptor &input_descriptor,
1132 const DeviceMemory<Eigen::half> &input_data,
1133 const dnn::BatchDescriptor &bias_descriptor,
1134 DeviceMemory<Eigen::half> *backward_bias_data) {
1135 return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
1136 bias_descriptor, backward_bias_data);
1137 }
1138
ThenMatMul(const DeviceMemory<float> & input_data,const DeviceMemory<float> & weights,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1139 Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
1140 const DeviceMemory<float> &weights,
1141 const dnn::BatchDescriptor &input_dimensions,
1142 const dnn::BatchDescriptor &output_dimensions,
1143 DeviceMemory<float> *output_data) {
1144 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions),
1145 PARAM(output_dimensions), PARAM(output_data));
1146
1147 if (ok()) {
1148 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1149 CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions,
1150 output_dimensions, output_data));
1151 } else {
1152 SetErrorAndLogNoDnnSupport();
1153 }
1154 }
1155 return *this;
1156 }
1157
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int8> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1158 Stream &Stream::ThenMatMulQuantized(
1159 const DeviceMemory<float> &input_data, const DeviceMemory<int8> &weights,
1160 const DeviceMemory<float> &weight_scales,
1161 const dnn::BatchDescriptor &input_dimensions,
1162 const dnn::BatchDescriptor &output_dimensions,
1163 DeviceMemory<float> *output_data) {
1164 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
1165 PARAM(input_dimensions), PARAM(output_dimensions),
1166 PARAM(output_data));
1167
1168 if (ok()) {
1169 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1170 CheckError(dnn->DoMatMulQuantized(this, input_data, weights,
1171 weight_scales, input_dimensions,
1172 output_dimensions, output_data));
1173 } else {
1174 SetErrorAndLogNoDnnSupport();
1175 }
1176 }
1177 return *this;
1178 }
1179
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int16> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1180 Stream &Stream::ThenMatMulQuantized(
1181 const DeviceMemory<float> &input_data, const DeviceMemory<int16> &weights,
1182 const DeviceMemory<float> &weight_scales,
1183 const dnn::BatchDescriptor &input_dimensions,
1184 const dnn::BatchDescriptor &output_dimensions,
1185 DeviceMemory<float> *output_data) {
1186 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
1187 PARAM(input_dimensions), PARAM(output_dimensions),
1188 PARAM(output_data));
1189
1190 if (ok()) {
1191 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1192 CheckError(dnn->DoMatMulQuantized(this, input_data, weights,
1193 weight_scales, input_dimensions,
1194 output_dimensions, output_data));
1195 } else {
1196 SetErrorAndLogNoDnnSupport();
1197 }
1198 }
1199 return *this;
1200 }
1201
ThenBiasAdd(const DeviceMemory<float> & input_data,const DeviceMemory<float> & biases,const dnn::BatchDescriptor & dimensions,DeviceMemory<float> * output_data)1202 Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data,
1203 const DeviceMemory<float> &biases,
1204 const dnn::BatchDescriptor &dimensions,
1205 DeviceMemory<float> *output_data) {
1206 VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions),
1207 PARAM(output_data));
1208
1209 if (ok()) {
1210 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1211 CheckError(
1212 dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data));
1213 } else {
1214 SetErrorAndLogNoDnnSupport();
1215 }
1216 }
1217 return *this;
1218 }
1219
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<double> * output_data,ScratchAllocator * workspace_allocator)1220 Stream &Stream::ThenPoolForward(
1221 const dnn::PoolingDescriptor &pooling_dimensions,
1222 const dnn::BatchDescriptor &input_dimensions,
1223 const DeviceMemory<double> &input_data,
1224 const dnn::BatchDescriptor &output_dimensions,
1225 DeviceMemory<double> *output_data, ScratchAllocator *workspace_allocator) {
1226 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1227 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1228 PARAM(workspace_allocator));
1229
1230 if (ok()) {
1231 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1232 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1233 input_data, output_dimensions, output_data,
1234 workspace_allocator));
1235 } else {
1236 SetError();
1237 LOG(WARNING)
1238 << "attempting to perform DNN operation using StreamExecutor "
1239 "without DNN support";
1240 }
1241 }
1242 return *this;
1243 }
1244
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data,ScratchAllocator * workspace_allocator)1245 Stream &Stream::ThenPoolForward(
1246 const dnn::PoolingDescriptor &pooling_dimensions,
1247 const dnn::BatchDescriptor &input_dimensions,
1248 const DeviceMemory<float> &input_data,
1249 const dnn::BatchDescriptor &output_dimensions,
1250 DeviceMemory<float> *output_data, ScratchAllocator *workspace_allocator) {
1251 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1252 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1253 PARAM(workspace_allocator));
1254
1255 if (ok()) {
1256 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1257 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1258 input_data, output_dimensions, output_data,
1259 workspace_allocator));
1260 } else {
1261 SetErrorAndLogNoDnnSupport();
1262 }
1263 }
1264 return *this;
1265 }
1266
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<Eigen::half> * output_data,ScratchAllocator * workspace_allocator)1267 Stream &Stream::ThenPoolForward(
1268 const dnn::PoolingDescriptor &pooling_dimensions,
1269 const dnn::BatchDescriptor &input_dimensions,
1270 const DeviceMemory<Eigen::half> &input_data,
1271 const dnn::BatchDescriptor &output_dimensions,
1272 DeviceMemory<Eigen::half> *output_data,
1273 ScratchAllocator *workspace_allocator) {
1274 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1275 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1276 PARAM(workspace_allocator));
1277
1278 if (ok()) {
1279 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1280 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1281 input_data, output_dimensions, output_data,
1282 workspace_allocator));
1283 } else {
1284 SetErrorAndLogNoDnnSupport();
1285 }
1286 }
1287 return *this;
1288 }
1289
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<int8> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<int8> * output_data,ScratchAllocator * workspace_allocator)1290 Stream &Stream::ThenPoolForward(
1291 const dnn::PoolingDescriptor &pooling_dimensions,
1292 const dnn::BatchDescriptor &input_dimensions,
1293 const DeviceMemory<int8> &input_data,
1294 const dnn::BatchDescriptor &output_dimensions,
1295 DeviceMemory<int8> *output_data, ScratchAllocator *workspace_allocator) {
1296 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1297 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1298 PARAM(workspace_allocator));
1299
1300 if (ok()) {
1301 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1302 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1303 input_data, output_dimensions, output_data,
1304 workspace_allocator));
1305 } else {
1306 SetErrorAndLogNoDnnSupport();
1307 }
1308 }
1309 return *this;
1310 }
1311
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<double> & output_data,const DeviceMemory<double> & input_diff_data,DeviceMemory<double> * output_diff_data,ScratchAllocator * workspace_allocator)1312 Stream &Stream::ThenPoolBackward(
1313 const dnn::PoolingDescriptor &pooling_dimensions,
1314 const dnn::BatchDescriptor &input_dimensions,
1315 const DeviceMemory<double> &input_data,
1316 const dnn::BatchDescriptor &output_dimensions,
1317 const DeviceMemory<double> &output_data,
1318 const DeviceMemory<double> &input_diff_data,
1319 DeviceMemory<double> *output_diff_data,
1320 ScratchAllocator *workspace_allocator) {
1321 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1322 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1323 PARAM(input_diff_data), PARAM(output_diff_data),
1324 PARAM(workspace_allocator));
1325
1326 if (ok()) {
1327 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1328 CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
1329 input_data, output_dimensions, output_data,
1330 input_diff_data, output_diff_data,
1331 workspace_allocator));
1332 } else {
1333 SetError();
1334 LOG(WARNING)
1335 << "attempting to perform DNN operation using StreamExecutor "
1336 "without DNN support";
1337 }
1338 }
1339 return *this;
1340 }
1341
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<float> & output_data,const DeviceMemory<float> & input_diff_data,DeviceMemory<float> * output_diff_data,ScratchAllocator * workspace_allocator)1342 Stream &Stream::ThenPoolBackward(
1343 const dnn::PoolingDescriptor &pooling_dimensions,
1344 const dnn::BatchDescriptor &input_dimensions,
1345 const DeviceMemory<float> &input_data,
1346 const dnn::BatchDescriptor &output_dimensions,
1347 const DeviceMemory<float> &output_data,
1348 const DeviceMemory<float> &input_diff_data,
1349 DeviceMemory<float> *output_diff_data,
1350 ScratchAllocator *workspace_allocator) {
1351 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1352 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1353 PARAM(input_diff_data), PARAM(output_diff_data),
1354 PARAM(workspace_allocator));
1355
1356 if (ok()) {
1357 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1358 CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
1359 input_data, output_dimensions, output_data,
1360 input_diff_data, output_diff_data,
1361 workspace_allocator));
1362 } else {
1363 SetErrorAndLogNoDnnSupport();
1364 }
1365 }
1366 return *this;
1367 }
1368
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<Eigen::half> & output_data,const DeviceMemory<Eigen::half> & input_diff_data,DeviceMemory<Eigen::half> * output_diff_data,ScratchAllocator * workspace_allocator)1369 Stream &Stream::ThenPoolBackward(
1370 const dnn::PoolingDescriptor &pooling_dimensions,
1371 const dnn::BatchDescriptor &input_dimensions,
1372 const DeviceMemory<Eigen::half> &input_data,
1373 const dnn::BatchDescriptor &output_dimensions,
1374 const DeviceMemory<Eigen::half> &output_data,
1375 const DeviceMemory<Eigen::half> &input_diff_data,
1376 DeviceMemory<Eigen::half> *output_diff_data,
1377 ScratchAllocator *workspace_allocator) {
1378 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1379 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1380 PARAM(input_diff_data), PARAM(output_diff_data),
1381 PARAM(workspace_allocator));
1382
1383 if (ok()) {
1384 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1385 CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
1386 input_data, output_dimensions, output_data,
1387 input_diff_data, output_diff_data,
1388 workspace_allocator));
1389 } else {
1390 SetErrorAndLogNoDnnSupport();
1391 }
1392 }
1393 return *this;
1394 }
1395
ThenNormalizeWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)1396 Stream &Stream::ThenNormalizeWithDimensions(
1397 const dnn::NormalizeDescriptor &normalize_descriptor,
1398 const dnn::BatchDescriptor &dimensions,
1399 const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
1400 VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data),
1401 PARAM(output_data));
1402
1403 if (ok()) {
1404 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1405 CheckError(dnn->DoNormalizeWithDimensions(
1406 this, normalize_descriptor, dimensions, input_data, output_data));
1407 } else {
1408 SetErrorAndLogNoDnnSupport();
1409 }
1410 }
1411 return *this;
1412 }
1413
ThenNormalizeBackwardWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & raw_data,const DeviceMemory<float> & normalized_data,const DeviceMemory<float> & normalized_variable_gradient,DeviceMemory<float> * raw_variable_gradient,ScratchAllocator * workspace_allocator)1414 Stream &Stream::ThenNormalizeBackwardWithDimensions(
1415 const dnn::NormalizeDescriptor &normalize_descriptor,
1416 const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data,
1417 const DeviceMemory<float> &normalized_data,
1418 const DeviceMemory<float> &normalized_variable_gradient,
1419 DeviceMemory<float> *raw_variable_gradient,
1420 ScratchAllocator *workspace_allocator) {
1421 VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data),
1422 PARAM(normalized_data), PARAM(normalized_variable_gradient),
1423 PARAM(raw_variable_gradient), PARAM(workspace_allocator));
1424
1425 if (ok()) {
1426 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1427 CheckError(dnn->DoNormalizeBackwardWithDimensions(
1428 this, normalize_descriptor, dimensions, raw_data, normalized_data,
1429 normalized_variable_gradient, raw_variable_gradient,
1430 workspace_allocator));
1431 } else {
1432 SetErrorAndLogNoDnnSupport();
1433 }
1434 }
1435 return *this;
1436 }
1437
ThenActivate(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)1438 Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode,
1439 const dnn::BatchDescriptor &dimensions,
1440 const DeviceMemory<float> &input_data,
1441 DeviceMemory<float> *output_data) {
1442 return ThenActivateWithOptions(activation_mode, dimensions, input_data,
1443 output_data, /*options=*/0);
1444 }
1445
ThenActivateWithOptions(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data,uint64 options)1446 Stream &Stream::ThenActivateWithOptions(dnn::ActivationMode activation_mode,
1447 const dnn::BatchDescriptor &dimensions,
1448 const DeviceMemory<float> &input_data,
1449 DeviceMemory<float> *output_data,
1450 uint64 options) {
1451 VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data),
1452 PARAM(output_data), PARAM(options));
1453
1454 if (ok()) {
1455 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1456 CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data,
1457 output_data, options));
1458 } else {
1459 SetErrorAndLogNoDnnSupport();
1460 }
1461 }
1462 return *this;
1463 }
1464
ThenDepthConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data)1465 Stream &Stream::ThenDepthConcatenate(
1466 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1467 port::ArraySlice<const DeviceMemory<float> *> input_data,
1468 DeviceMemory<float> *output_data) {
1469 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
1470
1471 for (size_t i = 1; i < input_dimensions.size(); ++i) {
1472 if (input_dimensions[i].count() != input_dimensions[0].count() ||
1473 input_dimensions[i].height() != input_dimensions[0].height() ||
1474 input_dimensions[i].width() != input_dimensions[0].width()) {
1475 SetError();
1476 LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n"
1477 << "input_dimensions[0]: " << input_dimensions[0].ToString()
1478 << "input_dimensions[" << i
1479 << "]: " << input_dimensions[i].ToString();
1480 return *this;
1481 }
1482 }
1483
1484 if (ok()) {
1485 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1486 CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data,
1487 output_data));
1488 } else {
1489 SetErrorAndLogNoDnnSupport();
1490 }
1491 }
1492 return *this;
1493 }
1494
ThenSpaceConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data,dnn::SpaceConcatenateMode concat_direction)1495 Stream &Stream::ThenSpaceConcatenate(
1496 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1497 port::ArraySlice<const DeviceMemory<float> *> input_data,
1498 DeviceMemory<float> *output_data,
1499 dnn::SpaceConcatenateMode concat_direction) {
1500 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
1501
1502 // Check that the input dimensions of all the other batches match those of the
1503 // first batch.
1504 for (size_t i = 1; i < input_dimensions.size(); ++i) {
1505 if ((concat_direction == dnn::SpaceConcatenateMode::XDirection) &&
1506 (input_dimensions[i].count() != input_dimensions[0].count() ||
1507 input_dimensions[i].height() != input_dimensions[0].height() ||
1508 input_dimensions[i].feature_map_count() !=
1509 input_dimensions[0].feature_map_count())) {
1510 SetError();
1511 LOG(ERROR) << "Incompatible dimensions for X concatenation.\n"
1512 << "input_dimensions[0]: " << input_dimensions[0].ToString()
1513 << "input_dimensions[" << i
1514 << "]: " << input_dimensions[i].ToString();
1515 return *this;
1516 }
1517
1518 if ((concat_direction == dnn::SpaceConcatenateMode::YDirection) &&
1519 (input_dimensions[i].count() != input_dimensions[0].count() ||
1520 input_dimensions[i].width() != input_dimensions[0].width() ||
1521 input_dimensions[i].feature_map_count() !=
1522 input_dimensions[0].feature_map_count())) {
1523 SetError();
1524 LOG(ERROR) << "Incompatible dimensions for Y concatenation.\n"
1525 << "input_dimensions[0]: " << input_dimensions[0].ToString()
1526 << "input_dimensions[" << i
1527 << "]: " << input_dimensions[i].ToString();
1528 return *this;
1529 }
1530 }
1531 if (ok()) {
1532 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1533 CheckError(dnn->DoSpaceConcatenate(this, input_dimensions, input_data,
1534 output_data, concat_direction));
1535 } else {
1536 SetErrorAndLogNoDnnSupport();
1537 }
1538 }
1539 return *this;
1540 }
1541
ThenReshape(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1542 Stream &Stream::ThenReshape(const dnn::BatchDescriptor &input_dimensions,
1543 const DeviceMemory<float> &input_data,
1544 const dnn::BatchDescriptor &output_dimensions,
1545 DeviceMemory<float> *output_data) {
1546 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1547 PARAM(output_dimensions), PARAM(output_data));
1548
1549 if (ok()) {
1550 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1551 CheckError(dnn->DoReshape(this, input_dimensions, input_data,
1552 output_dimensions, output_data));
1553 } else {
1554 SetErrorAndLogNoDnnSupport();
1555 }
1556 }
1557 return *this;
1558 }
1559
ThenDepthToSpace(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & depth_to_space_layout,const int sqrt_depth_reduction,DeviceMemory<float> * output_data)1560 Stream &Stream::ThenDepthToSpace(
1561 const dnn::BatchDescriptor &input_dimensions,
1562 const DeviceMemory<float> &input_data,
1563 const dnn::DepthToSpaceLayout &depth_to_space_layout,
1564 const int sqrt_depth_reduction, DeviceMemory<float> *output_data) {
1565 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1566 PARAM(depth_to_space_layout), PARAM(sqrt_depth_reduction),
1567 PARAM(output_data));
1568
1569 if (ok()) {
1570 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1571 CheckError(dnn->DoDepthToSpace(this, input_dimensions, input_data,
1572 depth_to_space_layout,
1573 sqrt_depth_reduction, output_data));
1574 } else {
1575 SetErrorAndLogNoDnnSupport();
1576 }
1577 }
1578 return *this;
1579 }
1580
ThenSpaceToDepth(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & space_to_depth_layout,const int sqrt_depth_increase,DeviceMemory<float> * output_data)1581 Stream &Stream::ThenSpaceToDepth(
1582 const dnn::BatchDescriptor &input_dimensions,
1583 const DeviceMemory<float> &input_data,
1584 const dnn::DepthToSpaceLayout &space_to_depth_layout,
1585 const int sqrt_depth_increase, DeviceMemory<float> *output_data) {
1586 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1587 PARAM(space_to_depth_layout), PARAM(sqrt_depth_increase),
1588 PARAM(output_data));
1589
1590 if (ok()) {
1591 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1592 CheckError(dnn->DoSpaceToDepth(this, input_dimensions, input_data,
1593 space_to_depth_layout, sqrt_depth_increase,
1594 output_data));
1595 } else {
1596 SetErrorAndLogNoDnnSupport();
1597 }
1598 }
1599 return *this;
1600 }
1601
ThenElementwiseOperate(dnn::ElementwiseOperation operation,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1602 Stream &Stream::ThenElementwiseOperate(
1603 dnn::ElementwiseOperation operation,
1604 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1605 port::ArraySlice<const DeviceMemory<float> *> input_data,
1606 const dnn::BatchDescriptor &output_dimensions,
1607 DeviceMemory<float> *output_data) {
1608 VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data),
1609 PARAM(output_dimensions), PARAM(output_data));
1610
1611 if (ok()) {
1612 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1613 CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions,
1614 input_data, output_dimensions,
1615 output_data));
1616 } else {
1617 SetErrorAndLogNoDnnSupport();
1618 }
1619 }
1620 return *this;
1621 }
1622
ThenElementwiseOperateScaledQuantized(dnn::ElementwiseOperation operation,port::ArraySlice<int> input_multiplicands,int output_divisor,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1623 Stream &Stream::ThenElementwiseOperateScaledQuantized(
1624 dnn::ElementwiseOperation operation,
1625 port::ArraySlice<int> input_multiplicands, int output_divisor,
1626 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1627 port::ArraySlice<const DeviceMemory<float> *> input_data,
1628 const dnn::BatchDescriptor &output_dimensions,
1629 DeviceMemory<float> *output_data) {
1630 VLOG_CALL(PARAM(operation), PARAM(input_multiplicands), PARAM(output_divisor),
1631 PARAM(input_dimensions), PARAM(input_data),
1632 PARAM(output_dimensions), PARAM(output_data));
1633
1634 if (ok()) {
1635 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1636 CheckError(dnn->DoElementwiseOperateScaledQuantized(
1637 this, operation, input_multiplicands, output_divisor,
1638 input_dimensions, input_data, output_dimensions, output_data));
1639 } else {
1640 SetErrorAndLogNoDnnSupport();
1641 }
1642 }
1643 return *this;
1644 }
1645
ThenXYPad(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 left_pad,int64 right_pad,int64 top_pad,int64 bottom_pad,DeviceMemory<float> * output_data)1646 Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions,
1647 const DeviceMemory<float> &input_data, int64 left_pad,
1648 int64 right_pad, int64 top_pad, int64 bottom_pad,
1649 DeviceMemory<float> *output_data) {
1650 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad),
1651 PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad),
1652 PARAM(output_data));
1653
1654 if (ok()) {
1655 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1656 CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad,
1657 top_pad, bottom_pad, output_data));
1658 } else {
1659 SetErrorAndLogNoDnnSupport();
1660 }
1661 }
1662 return *this;
1663 }
1664
ThenXYSlice(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 left_trim,int64 right_trim,int64 top_trim,int64 bottom_trim,DeviceMemory<float> * output_data)1665 Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions,
1666 const DeviceMemory<float> &input_data,
1667 int64 left_trim, int64 right_trim, int64 top_trim,
1668 int64 bottom_trim,
1669 DeviceMemory<float> *output_data) {
1670 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim),
1671 PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim),
1672 PARAM(output_data));
1673
1674 if (ok()) {
1675 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1676 CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim,
1677 right_trim, top_trim, bottom_trim,
1678 output_data));
1679 } else {
1680 SetErrorAndLogNoDnnSupport();
1681 }
1682 }
1683 return *this;
1684 }
1685
ThenXYBroadcast(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 replicate_x,int64 replicate_y,DeviceMemory<float> * output_data)1686 Stream &Stream::ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
1687 const DeviceMemory<float> &input_data,
1688 int64 replicate_x, int64 replicate_y,
1689 DeviceMemory<float> *output_data) {
1690 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(replicate_x),
1691 PARAM(replicate_y), PARAM(output_data));
1692
1693 if (ok()) {
1694 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1695 CheckError(dnn->DoXYBroadcast(this, dimensions, input_data, replicate_x,
1696 replicate_y, output_data));
1697 } else {
1698 SetErrorAndLogNoDnnSupport();
1699 }
1700 }
1701 return *this;
1702 }
1703
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,dnn::QuantizedActivationMode mode,void * host_dst,uint64 size)1704 Stream &Stream::ThenMemcpyD2HQuantized(
1705 const DeviceMemory<float> &gpu_unquantized_src,
1706 dnn::QuantizedActivationMode mode, void *host_dst, uint64 size) {
1707 VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst),
1708 PARAM(size));
1709
1710 if (ok()) {
1711 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1712 CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode,
1713 host_dst, size));
1714 } else {
1715 SetErrorAndLogNoDnnSupport();
1716 }
1717 }
1718 return *this;
1719 }
1720
ThenMemcpyH2DQuantized(const void * host_src,uint64 size,dnn::QuantizedActivationMode mode,DeviceMemory<float> * gpu_unquantized_dst)1721 Stream &Stream::ThenMemcpyH2DQuantized(
1722 const void *host_src, uint64 size, dnn::QuantizedActivationMode mode,
1723 DeviceMemory<float> *gpu_unquantized_dst) {
1724 VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode),
1725 PARAM(gpu_unquantized_dst));
1726
1727 if (ok()) {
1728 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1729 CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode,
1730 gpu_unquantized_dst));
1731 } else {
1732 SetErrorAndLogNoDnnSupport();
1733 }
1734 }
1735 return *this;
1736 }
1737
GetOrCreateSubStream()1738 Stream *Stream::GetOrCreateSubStream() {
1739 mutex_lock lock(mu_);
1740
1741 // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
1742 // we encounter along the way.
1743 for (int64 index = 0; index < sub_streams_.size();) {
1744 std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
1745 if (pair.second) {
1746 // The sub_stream is reusable.
1747 Stream *sub_stream = pair.first.get();
1748 if (sub_stream->ok()) {
1749 VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
1750 << sub_stream->DebugStreamPointers();
1751 pair.second = false;
1752 return sub_stream;
1753 }
1754
1755 // The stream is reusable and not ok. Streams have a monotonic state
1756 // machine; the stream will remain in !ok forever. Swap it with the last
1757 // stream and pop it off.
1758 const int64 last = sub_streams_.size() - 1;
1759 if (index != last) {
1760 std::swap(pair, sub_streams_[last]);
1761 }
1762 sub_streams_.pop_back();
1763 VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
1764 << sub_stream->DebugStreamPointers();
1765 } else {
1766 // The sub_stream is not reusable, move on to the next one.
1767 ++index;
1768 }
1769 }
1770
1771 // No streams are reusable; create a new stream.
1772 sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
1773 false);
1774 Stream *sub_stream = sub_streams_.back().first.get();
1775 sub_stream->Init();
1776 if (!sub_stream->ok_) {
1777 LOG(ERROR) << "sub-stream failed to be initialized";
1778 }
1779 VLOG(1) << DebugStreamPointers() << " created new sub_stream "
1780 << sub_stream->DebugStreamPointers();
1781
1782 return sub_stream;
1783 }
1784
ReturnSubStream(Stream * sub_stream)1785 void Stream::ReturnSubStream(Stream *sub_stream) {
1786 mutex_lock lock(mu_);
1787
1788 // Look for the sub-stream.
1789 for (int64 index = 0; index < sub_streams_.size(); ++index) {
1790 std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
1791 if (pair.first.get() != sub_stream) {
1792 continue;
1793 }
1794
1795 // Found the sub_stream.
1796 if (sub_stream->ok()) {
1797 VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
1798 << sub_stream->DebugStreamPointers();
1799 pair.second = true;
1800 } else {
1801 // The returned stream is not ok. Streams have a monotonic state
1802 // machine; the stream will remain in !ok forever. Swap it with the last
1803 // stream and pop it off.
1804 VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
1805 << sub_stream->DebugStreamPointers();
1806 const int64 last = sub_streams_.size() - 1;
1807 if (index != last) {
1808 std::swap(pair, sub_streams_[last]);
1809 }
1810 sub_streams_.pop_back();
1811 }
1812 return;
1813 }
1814
1815 LOG(FATAL) << DebugStreamPointers()
1816 << " did not create the returned sub-stream "
1817 << sub_stream->DebugStreamPointers();
1818 }
1819
ThenStartTimer(Timer * t)1820 Stream &Stream::ThenStartTimer(Timer *t) {
1821 VLOG_CALL(PARAM(t));
1822
1823 if (ok()) {
1824 CheckError(parent_->StartTimer(this, t));
1825 } else {
1826 LOG(INFO) << DebugStreamPointers()
1827 << " did not enqueue 'start timer': " << t;
1828 }
1829 return *this;
1830 }
1831
ThenStopTimer(Timer * t)1832 Stream &Stream::ThenStopTimer(Timer *t) {
1833 VLOG_CALL(PARAM(t));
1834
1835 if (ok()) {
1836 CheckError(parent_->StopTimer(this, t));
1837 } else {
1838 LOG(INFO) << DebugStreamPointers()
1839 << " did not enqueue 'stop timer': " << t;
1840 }
1841 return *this;
1842 }
1843
ThenWaitFor(Stream * other)1844 Stream &Stream::ThenWaitFor(Stream *other) {
1845 VLOG_CALL(PARAM(other));
1846
1847 CHECK(this != other) << "stream cannot wait for itself";
1848 if (ok() && other->ok()) {
1849 CheckError(parent_->CreateStreamDependency(this, other));
1850 } else {
1851 SetError();
1852 LOG(INFO) << DebugStreamPointers() << " did not wait for "
1853 << other->DebugStreamPointers();
1854 }
1855 return *this;
1856 }
1857
ThenWaitFor(Event * event)1858 Stream &Stream::ThenWaitFor(Event *event) {
1859 VLOG_CALL(PARAM(event));
1860
1861 if (ok()) {
1862 port::Status status = parent_->WaitForEvent(this, event);
1863 if (!status.ok()) {
1864 LOG(ERROR) << "Error waiting for event in stream: "
1865 << status.error_message()
1866 << "; not marking stream as bad, as the Event object may be "
1867 << "at fault. Monitor for further errors.";
1868 }
1869 } else {
1870 LOG(INFO) << DebugStreamPointers() << " did not wait for an event.";
1871 }
1872 return *this;
1873 }
1874
1875 // A functor that implements ThenBlasXXX interfaces, which calls DoBlasXXX
1876 // functions and logs for errors.
1877 template <typename... Args>
1878 struct ThenBlasImpl {
1879 // blas_func is the DoBlasXXX member function pointer, and args are its
1880 // arguments except the first one of Stream* type.
operator ()stream_executor::ThenBlasImpl1881 Stream &operator()(Stream *stream,
1882 bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1883 Args... args) {
1884 return Run(stream, blas_func, /*record_error=*/true, args...);
1885 }
1886
1887 // Like operator(), but only calls stream->CheckError() if record_error is
1888 // true.
1889 Stream &Run(Stream *stream,
1890 bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1891 bool record_error, Args... args);
1892 };
1893
1894 template <typename... Args>
Run(Stream * stream,bool (blas::BlasSupport::* blas_func)(Stream *,Args...),bool record_error,Args...args)1895 Stream &ThenBlasImpl<Args...>::Run(
1896 Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1897 bool record_error, Args... args) {
1898 if (stream->ok()) {
1899 bool ok;
1900 if (blas::BlasSupport *blas = stream->parent_->AsBlas()) {
1901 ok = (blas->*blas_func)(stream, args...);
1902 } else {
1903 LOG(WARNING)
1904 << "attempting to perform BLAS operation using StreamExecutor "
1905 "without BLAS support";
1906 ok = false;
1907 }
1908 if (record_error) {
1909 stream->CheckError(ok);
1910 }
1911 }
1912 return *stream;
1913 }
1914
ThenBlasAsum(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)1915 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
1916 int incx, DeviceMemory<float> *result) {
1917 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1918
1919 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
1920 impl;
1921 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1922 result);
1923 }
1924
ThenBlasAsum(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)1925 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
1926 int incx, DeviceMemory<double> *result) {
1927 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1928
1929 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1930 DeviceMemory<double> *> impl;
1931 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1932 result);
1933 }
1934
ThenBlasAsum(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)1935 Stream &Stream::ThenBlasAsum(uint64 elem_count,
1936 const DeviceMemory<std::complex<float>> &x,
1937 int incx, DeviceMemory<float> *result) {
1938 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1939
1940 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1941 DeviceMemory<float> *> impl;
1942 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1943 result);
1944 }
1945
ThenBlasAsum(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)1946 Stream &Stream::ThenBlasAsum(uint64 elem_count,
1947 const DeviceMemory<std::complex<double>> &x,
1948 int incx, DeviceMemory<double> *result) {
1949 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1950
1951 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1952 DeviceMemory<double> *> impl;
1953 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1954 result);
1955 }
1956
ThenBlasAxpy(uint64 elem_count,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)1957 Stream &Stream::ThenBlasAxpy(uint64 elem_count, float alpha,
1958 const DeviceMemory<float> &x, int incx,
1959 DeviceMemory<float> *y, int incy) {
1960 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1961 PARAM(incy));
1962
1963 ThenBlasImpl<uint64, float, const DeviceMemory<float> &, int,
1964 DeviceMemory<float> *, int> impl;
1965 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1966 y, incy);
1967 }
1968
ThenBlasAxpy(uint64 elem_count,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)1969 Stream &Stream::ThenBlasAxpy(uint64 elem_count, double alpha,
1970 const DeviceMemory<double> &x, int incx,
1971 DeviceMemory<double> *y, int incy) {
1972 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1973 PARAM(incy));
1974
1975 ThenBlasImpl<uint64, double, const DeviceMemory<double> &, int,
1976 DeviceMemory<double> *, int> impl;
1977 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1978 y, incy);
1979 }
1980
ThenBlasAxpy(uint64 elem_count,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1981 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
1982 const DeviceMemory<std::complex<float>> &x,
1983 int incx, DeviceMemory<std::complex<float>> *y,
1984 int incy) {
1985 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1986 PARAM(incy));
1987
1988 ThenBlasImpl<uint64, std::complex<float>,
1989 const DeviceMemory<std::complex<float>> &, int,
1990 DeviceMemory<std::complex<float>> *, int> impl;
1991 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1992 y, incy);
1993 }
1994
ThenBlasAxpy(uint64 elem_count,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1995 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
1996 const DeviceMemory<std::complex<double>> &x,
1997 int incx, DeviceMemory<std::complex<double>> *y,
1998 int incy) {
1999 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2000 PARAM(incy));
2001
2002 ThenBlasImpl<uint64, std::complex<double>,
2003 const DeviceMemory<std::complex<double>> &, int,
2004 DeviceMemory<std::complex<double>> *, int> impl;
2005 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
2006 y, incy);
2007 }
2008
ThenBlasCopy(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)2009 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
2010 int incx, DeviceMemory<float> *y, int incy) {
2011 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2012
2013 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
2014 int> impl;
2015 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2016 incy);
2017 }
2018
ThenBlasCopy(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)2019 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
2020 int incx, DeviceMemory<double> *y, int incy) {
2021 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2022
2023 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2024 DeviceMemory<double> *, int> impl;
2025 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2026 incy);
2027 }
2028
ThenBlasCopy(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)2029 Stream &Stream::ThenBlasCopy(uint64 elem_count,
2030 const DeviceMemory<std::complex<float>> &x,
2031 int incx, DeviceMemory<std::complex<float>> *y,
2032 int incy) {
2033 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2034
2035 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2036 DeviceMemory<std::complex<float>> *, int> impl;
2037 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2038 incy);
2039 }
2040
ThenBlasCopy(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)2041 Stream &Stream::ThenBlasCopy(uint64 elem_count,
2042 const DeviceMemory<std::complex<double>> &x,
2043 int incx, DeviceMemory<std::complex<double>> *y,
2044 int incy) {
2045 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2046
2047 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2048 DeviceMemory<std::complex<double>> *, int> impl;
2049 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2050 incy);
2051 }
2052
ThenBlasDot(uint64 elem_count,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * result)2053 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x,
2054 int incx, const DeviceMemory<float> &y, int incy,
2055 DeviceMemory<float> *result) {
2056 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2057 PARAM(result));
2058
2059 ThenBlasImpl<uint64, const DeviceMemory<float> &, int,
2060 const DeviceMemory<float> &, int, DeviceMemory<float> *> impl;
2061 return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
2062 result);
2063 }
2064
ThenBlasDot(uint64 elem_count,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * result)2065 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
2066 int incx, const DeviceMemory<double> &y, int incy,
2067 DeviceMemory<double> *result) {
2068 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2069 PARAM(result));
2070
2071 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2072 const DeviceMemory<double> &, int, DeviceMemory<double> *> impl;
2073 return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
2074 result);
2075 }
2076
ThenBlasDotc(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)2077 Stream &Stream::ThenBlasDotc(uint64 elem_count,
2078 const DeviceMemory<std::complex<float>> &x,
2079 int incx,
2080 const DeviceMemory<std::complex<float>> &y,
2081 int incy,
2082 DeviceMemory<std::complex<float>> *result) {
2083 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2084 PARAM(result));
2085
2086 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2087 const DeviceMemory<std::complex<float>> &, int,
2088 DeviceMemory<std::complex<float>> *> impl;
2089 return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
2090 incy, result);
2091 }
2092
ThenBlasDotc(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)2093 Stream &Stream::ThenBlasDotc(uint64 elem_count,
2094 const DeviceMemory<std::complex<double>> &x,
2095 int incx,
2096 const DeviceMemory<std::complex<double>> &y,
2097 int incy,
2098 DeviceMemory<std::complex<double>> *result) {
2099 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2100 PARAM(result));
2101
2102 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2103 const DeviceMemory<std::complex<double>> &, int,
2104 DeviceMemory<std::complex<double>> *> impl;
2105 return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
2106 incy, result);
2107 }
2108
ThenBlasDotu(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)2109 Stream &Stream::ThenBlasDotu(uint64 elem_count,
2110 const DeviceMemory<std::complex<float>> &x,
2111 int incx,
2112 const DeviceMemory<std::complex<float>> &y,
2113 int incy,
2114 DeviceMemory<std::complex<float>> *result) {
2115 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2116 PARAM(result));
2117
2118 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2119 const DeviceMemory<std::complex<float>> &, int,
2120 DeviceMemory<std::complex<float>> *> impl;
2121 return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
2122 incy, result);
2123 }
2124
ThenBlasDotu(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)2125 Stream &Stream::ThenBlasDotu(uint64 elem_count,
2126 const DeviceMemory<std::complex<double>> &x,
2127 int incx,
2128 const DeviceMemory<std::complex<double>> &y,
2129 int incy,
2130 DeviceMemory<std::complex<double>> *result) {
2131 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2132 PARAM(result));
2133
2134 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2135 const DeviceMemory<std::complex<double>> &, int,
2136 DeviceMemory<std::complex<double>> *> impl;
2137 return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
2138 incy, result);
2139 }
2140
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)2141 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
2142 int incx, DeviceMemory<float> *result) {
2143 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2144
2145 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
2146 impl;
2147 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2148 result);
2149 }
2150
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)2151 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
2152 int incx, DeviceMemory<double> *result) {
2153 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2154
2155 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2156 DeviceMemory<double> *> impl;
2157 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2158 result);
2159 }
2160
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)2161 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
2162 const DeviceMemory<std::complex<float>> &x,
2163 int incx, DeviceMemory<float> *result) {
2164 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2165
2166 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2167 DeviceMemory<float> *> impl;
2168 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2169 result);
2170 }
2171
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)2172 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
2173 const DeviceMemory<std::complex<double>> &x,
2174 int incx, DeviceMemory<double> *result) {
2175 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2176
2177 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2178 DeviceMemory<double> *> impl;
2179 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2180 result);
2181 }
2182
ThenBlasRot(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,float c,float s)2183 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
2184 DeviceMemory<float> *y, int incy, float c,
2185 float s) {
2186 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2187 PARAM(c), PARAM(s));
2188
2189 ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
2190 float, float> impl;
2191 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2192 c, s);
2193 }
2194
ThenBlasRot(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,double c,double s)2195 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x,
2196 int incx, DeviceMemory<double> *y, int incy,
2197 double c, double s) {
2198 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2199 PARAM(c), PARAM(s));
2200
2201 ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
2202 double, double> impl;
2203 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2204 c, s);
2205 }
2206
ThenBlasRot(uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy,float c,float s)2207 Stream &Stream::ThenBlasRot(uint64 elem_count,
2208 DeviceMemory<std::complex<float>> *x, int incx,
2209 DeviceMemory<std::complex<float>> *y, int incy,
2210 float c, float s) {
2211 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2212 PARAM(c), PARAM(s));
2213
2214 ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
2215 DeviceMemory<std::complex<float>> *, int, float, float> impl;
2216 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2217 c, s);
2218 }
2219
ThenBlasRot(uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy,double c,double s)2220 Stream &Stream::ThenBlasRot(uint64 elem_count,
2221 DeviceMemory<std::complex<double>> *x, int incx,
2222 DeviceMemory<std::complex<double>> *y, int incy,
2223 double c, double s) {
2224 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2225 PARAM(c), PARAM(s));
2226
2227 ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
2228 DeviceMemory<std::complex<double>> *, int, double, double> impl;
2229 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2230 c, s);
2231 }
2232
ThenBlasRotg(DeviceMemory<float> * a,DeviceMemory<float> * b,DeviceMemory<float> * c,DeviceMemory<float> * s)2233 Stream &Stream::ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
2234 DeviceMemory<float> *c, DeviceMemory<float> *s) {
2235 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2236
2237 ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
2238 DeviceMemory<float> *, DeviceMemory<float> *> impl;
2239 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2240 }
2241
ThenBlasRotg(DeviceMemory<double> * a,DeviceMemory<double> * b,DeviceMemory<double> * c,DeviceMemory<double> * s)2242 Stream &Stream::ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
2243 DeviceMemory<double> *c, DeviceMemory<double> *s) {
2244 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2245
2246 ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
2247 DeviceMemory<double> *, DeviceMemory<double> *> impl;
2248 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2249 }
2250
ThenBlasRotg(DeviceMemory<std::complex<float>> * a,DeviceMemory<std::complex<float>> * b,DeviceMemory<float> * c,DeviceMemory<std::complex<float>> * s)2251 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
2252 DeviceMemory<std::complex<float>> *b,
2253 DeviceMemory<float> *c,
2254 DeviceMemory<std::complex<float>> *s) {
2255 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2256
2257 ThenBlasImpl<DeviceMemory<std::complex<float>> *,
2258 DeviceMemory<std::complex<float>> *, DeviceMemory<float> *,
2259 DeviceMemory<std::complex<float>> *> impl;
2260 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2261 }
2262
ThenBlasRotg(DeviceMemory<std::complex<double>> * a,DeviceMemory<std::complex<double>> * b,DeviceMemory<double> * c,DeviceMemory<std::complex<double>> * s)2263 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
2264 DeviceMemory<std::complex<double>> *b,
2265 DeviceMemory<double> *c,
2266 DeviceMemory<std::complex<double>> *s) {
2267 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2268
2269 ThenBlasImpl<DeviceMemory<std::complex<double>> *,
2270 DeviceMemory<std::complex<double>> *, DeviceMemory<double> *,
2271 DeviceMemory<std::complex<double>> *> impl;
2272 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2273 }
2274
ThenBlasRotm(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,const DeviceMemory<float> & param)2275 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x,
2276 int incx, DeviceMemory<float> *y, int incy,
2277 const DeviceMemory<float> ¶m) {
2278 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2279 PARAM(param));
2280
2281 ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
2282 const DeviceMemory<float> &> impl;
2283 return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
2284 incy, param);
2285 }
2286
ThenBlasRotm(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,const DeviceMemory<double> & param)2287 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x,
2288 int incx, DeviceMemory<double> *y, int incy,
2289 const DeviceMemory<double> ¶m) {
2290 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2291 PARAM(param));
2292
2293 ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
2294 const DeviceMemory<double> &> impl;
2295 return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
2296 incy, param);
2297 }
2298
ThenBlasRotmg(DeviceMemory<float> * d1,DeviceMemory<float> * d2,DeviceMemory<float> * x1,const DeviceMemory<float> & y1,DeviceMemory<float> * param)2299 Stream &Stream::ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
2300 DeviceMemory<float> *x1,
2301 const DeviceMemory<float> &y1,
2302 DeviceMemory<float> *param) {
2303 VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
2304
2305 ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
2306 DeviceMemory<float> *, const DeviceMemory<float> &,
2307 DeviceMemory<float> *> impl;
2308 return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
2309 }
2310
ThenBlasRotmg(DeviceMemory<double> * d1,DeviceMemory<double> * d2,DeviceMemory<double> * x1,const DeviceMemory<double> & y1,DeviceMemory<double> * param)2311 Stream &Stream::ThenBlasRotmg(DeviceMemory<double> *d1,
2312 DeviceMemory<double> *d2,
2313 DeviceMemory<double> *x1,
2314 const DeviceMemory<double> &y1,
2315 DeviceMemory<double> *param) {
2316 VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
2317
2318 ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
2319 DeviceMemory<double> *, const DeviceMemory<double> &,
2320 DeviceMemory<double> *> impl;
2321 return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
2322 }
2323
ThenBlasScal(uint64 elem_count,float alpha,DeviceMemory<float> * x,int incx)2324 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
2325 DeviceMemory<float> *x, int incx) {
2326 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2327
2328 ThenBlasImpl<uint64, float, DeviceMemory<float> *, int> impl;
2329 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2330 }
2331
ThenBlasScal(uint64 elem_count,double alpha,DeviceMemory<double> * x,int incx)2332 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
2333 DeviceMemory<double> *x, int incx) {
2334 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2335
2336 ThenBlasImpl<uint64, double, DeviceMemory<double> *, int> impl;
2337 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2338 }
2339
ThenBlasScal(uint64 elem_count,float alpha,DeviceMemory<std::complex<float>> * x,int incx)2340 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
2341 DeviceMemory<std::complex<float>> *x, int incx) {
2342 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2343
2344 ThenBlasImpl<uint64, float, DeviceMemory<std::complex<float>> *, int> impl;
2345 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2346 }
2347
ThenBlasScal(uint64 elem_count,double alpha,DeviceMemory<std::complex<double>> * x,int incx)2348 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
2349 DeviceMemory<std::complex<double>> *x, int incx) {
2350 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2351
2352 ThenBlasImpl<uint64, double, DeviceMemory<std::complex<double>> *, int> impl;
2353 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2354 }
2355
ThenBlasScal(uint64 elem_count,std::complex<float> alpha,DeviceMemory<std::complex<float>> * x,int incx)2356 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
2357 DeviceMemory<std::complex<float>> *x, int incx) {
2358 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2359
2360 ThenBlasImpl<uint64, std::complex<float>, DeviceMemory<std::complex<float>> *,
2361 int> impl;
2362 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2363 }
2364
ThenBlasScal(uint64 elem_count,std::complex<double> alpha,DeviceMemory<std::complex<double>> * x,int incx)2365 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
2366 DeviceMemory<std::complex<double>> *x, int incx) {
2367 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2368
2369 ThenBlasImpl<uint64, std::complex<double>,
2370 DeviceMemory<std::complex<double>> *, int> impl;
2371 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2372 }
2373
ThenBlasSwap(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy)2374 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x,
2375 int incx, DeviceMemory<float> *y, int incy) {
2376 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2377
2378 ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int>
2379 impl;
2380 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2381 incy);
2382 }
2383
ThenBlasSwap(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy)2384 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x,
2385 int incx, DeviceMemory<double> *y, int incy) {
2386 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2387
2388 ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int>
2389 impl;
2390 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2391 incy);
2392 }
2393
ThenBlasSwap(uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy)2394 Stream &Stream::ThenBlasSwap(uint64 elem_count,
2395 DeviceMemory<std::complex<float>> *x, int incx,
2396 DeviceMemory<std::complex<float>> *y, int incy) {
2397 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2398
2399 ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
2400 DeviceMemory<std::complex<float>> *, int> impl;
2401 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2402 incy);
2403 }
2404
ThenBlasSwap(uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy)2405 Stream &Stream::ThenBlasSwap(uint64 elem_count,
2406 DeviceMemory<std::complex<double>> *x, int incx,
2407 DeviceMemory<std::complex<double>> *y, int incy) {
2408 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2409
2410 ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
2411 DeviceMemory<std::complex<double>> *, int> impl;
2412 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2413 incy);
2414 }
2415
ThenBlasIamax(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)2416 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
2417 int incx, DeviceMemory<int> *result) {
2418 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2419
2420 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
2421 impl;
2422 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2423 result);
2424 }
2425
ThenBlasIamax(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)2426 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
2427 int incx, DeviceMemory<int> *result) {
2428 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2429
2430 ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
2431 impl;
2432 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2433 result);
2434 }
2435
ThenBlasIamax(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)2436 Stream &Stream::ThenBlasIamax(uint64 elem_count,
2437 const DeviceMemory<std::complex<float>> &x,
2438 int incx, DeviceMemory<int> *result) {
2439 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2440
2441 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2442 DeviceMemory<int> *> impl;
2443 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2444 result);
2445 }
2446
ThenBlasIamax(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)2447 Stream &Stream::ThenBlasIamax(uint64 elem_count,
2448 const DeviceMemory<std::complex<double>> &x,
2449 int incx, DeviceMemory<int> *result) {
2450 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2451
2452 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2453 DeviceMemory<int> *> impl;
2454 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2455 result);
2456 }
2457
ThenBlasIamin(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)2458 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
2459 int incx, DeviceMemory<int> *result) {
2460 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2461
2462 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
2463 impl;
2464 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2465 result);
2466 }
2467
ThenBlasIamin(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)2468 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
2469 int incx, DeviceMemory<int> *result) {
2470 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2471
2472 ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
2473 impl;
2474 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2475 result);
2476 }
2477
ThenBlasIamin(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)2478 Stream &Stream::ThenBlasIamin(uint64 elem_count,
2479 const DeviceMemory<std::complex<float>> &x,
2480 int incx, DeviceMemory<int> *result) {
2481 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2482
2483 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2484 DeviceMemory<int> *> impl;
2485 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2486 result);
2487 }
2488
ThenBlasIamin(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)2489 Stream &Stream::ThenBlasIamin(uint64 elem_count,
2490 const DeviceMemory<std::complex<double>> &x,
2491 int incx, DeviceMemory<int> *result) {
2492 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2493
2494 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2495 DeviceMemory<int> *> impl;
2496 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2497 result);
2498 }
2499
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2500 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2501 uint64 kl, uint64 ku, float alpha,
2502 const DeviceMemory<float> &a, int lda,
2503 const DeviceMemory<float> &x, int incx, float beta,
2504 DeviceMemory<float> *y, int incy) {
2505 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2506 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2507 PARAM(beta), PARAM(y), PARAM(incy));
2508
2509 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, float,
2510 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2511 int, float, DeviceMemory<float> *, int> impl;
2512 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2513 a, lda, x, incx, beta, y, incy);
2514 }
2515
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2516 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2517 uint64 kl, uint64 ku, double alpha,
2518 const DeviceMemory<double> &a, int lda,
2519 const DeviceMemory<double> &x, int incx,
2520 double beta, DeviceMemory<double> *y, int incy) {
2521 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2522 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2523 PARAM(beta), PARAM(y), PARAM(incy));
2524
2525 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, double,
2526 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2527 int, double, DeviceMemory<double> *, int> impl;
2528 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2529 a, lda, x, incx, beta, y, incy);
2530 }
2531
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2532 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2533 uint64 kl, uint64 ku, std::complex<float> alpha,
2534 const DeviceMemory<std::complex<float>> &a,
2535 int lda,
2536 const DeviceMemory<std::complex<float>> &x,
2537 int incx, std::complex<float> beta,
2538 DeviceMemory<std::complex<float>> *y, int incy) {
2539 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2540 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2541 PARAM(beta), PARAM(y), PARAM(incy));
2542
2543 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
2544 std::complex<float>, const DeviceMemory<std::complex<float>> &,
2545 int, const DeviceMemory<std::complex<float>> &, int,
2546 std::complex<float>, DeviceMemory<std::complex<float>> *,
2547 int> impl;
2548 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2549 a, lda, x, incx, beta, y, incy);
2550 }
2551
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2552 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2553 uint64 kl, uint64 ku, std::complex<double> alpha,
2554 const DeviceMemory<std::complex<double>> &a,
2555 int lda,
2556 const DeviceMemory<std::complex<double>> &x,
2557 int incx, std::complex<double> beta,
2558 DeviceMemory<std::complex<double>> *y, int incy) {
2559 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2560 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2561 PARAM(beta), PARAM(y), PARAM(incy));
2562
2563 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
2564 std::complex<double>, const DeviceMemory<std::complex<double>> &,
2565 int, const DeviceMemory<std::complex<double>> &, int,
2566 std::complex<double>, DeviceMemory<std::complex<double>> *,
2567 int> impl;
2568 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2569 a, lda, x, incx, beta, y, incy);
2570 }
2571
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2572 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2573 float alpha, const DeviceMemory<float> &a, int lda,
2574 const DeviceMemory<float> &x, int incx, float beta,
2575 DeviceMemory<float> *y, int incy) {
2576 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2577 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2578 PARAM(incy));
2579
2580 ThenBlasImpl<blas::Transpose, uint64, uint64, float,
2581 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2582 int, float, DeviceMemory<float> *, int> impl;
2583 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2584 x, incx, beta, y, incy);
2585 }
2586
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2587 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2588 double alpha, const DeviceMemory<double> &a,
2589 int lda, const DeviceMemory<double> &x, int incx,
2590 double beta, DeviceMemory<double> *y, int incy) {
2591 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2592 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2593 PARAM(incy));
2594
2595 ThenBlasImpl<blas::Transpose, uint64, uint64, double,
2596 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2597 int, double, DeviceMemory<double> *, int> impl;
2598 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2599 x, incx, beta, y, incy);
2600 }
2601
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2602 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2603 std::complex<float> alpha,
2604 const DeviceMemory<std::complex<float>> &a,
2605 int lda,
2606 const DeviceMemory<std::complex<float>> &x,
2607 int incx, std::complex<float> beta,
2608 DeviceMemory<std::complex<float>> *y, int incy) {
2609 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2610 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2611 PARAM(incy));
2612
2613 ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<float>,
2614 const DeviceMemory<std::complex<float>> &, int,
2615 const DeviceMemory<std::complex<float>> &, int,
2616 std::complex<float>, DeviceMemory<std::complex<float>> *,
2617 int> impl;
2618 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2619 x, incx, beta, y, incy);
2620 }
2621
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2622 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2623 std::complex<double> alpha,
2624 const DeviceMemory<std::complex<double>> &a,
2625 int lda,
2626 const DeviceMemory<std::complex<double>> &x,
2627 int incx, std::complex<double> beta,
2628 DeviceMemory<std::complex<double>> *y, int incy) {
2629 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2630 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2631 PARAM(incy));
2632
2633 ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<double>,
2634 const DeviceMemory<std::complex<double>> &, int,
2635 const DeviceMemory<std::complex<double>> &, int,
2636 std::complex<double>, DeviceMemory<std::complex<double>> *,
2637 int> impl;
2638 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2639 x, incx, beta, y, incy);
2640 }
2641
ThenBlasGer(uint64 m,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)2642 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, float alpha,
2643 const DeviceMemory<float> &x, int incx,
2644 const DeviceMemory<float> &y, int incy,
2645 DeviceMemory<float> *a, int lda) {
2646 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2647 PARAM(incy), PARAM(a), PARAM(lda));
2648
2649 ThenBlasImpl<uint64, uint64, float, const DeviceMemory<float> &, int,
2650 const DeviceMemory<float> &, int, DeviceMemory<float> *,
2651 int> impl;
2652 return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2653 incy, a, lda);
2654 }
2655
ThenBlasGer(uint64 m,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)2656 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, double alpha,
2657 const DeviceMemory<double> &x, int incx,
2658 const DeviceMemory<double> &y, int incy,
2659 DeviceMemory<double> *a, int lda) {
2660 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2661 PARAM(incy), PARAM(a), PARAM(lda));
2662
2663 ThenBlasImpl<uint64, uint64, double, const DeviceMemory<double> &, int,
2664 const DeviceMemory<double> &, int, DeviceMemory<double> *,
2665 int> impl;
2666 return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2667 incy, a, lda);
2668 }
2669
ThenBlasGerc(uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2670 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
2671 const DeviceMemory<std::complex<float>> &x,
2672 int incx,
2673 const DeviceMemory<std::complex<float>> &y,
2674 int incy, DeviceMemory<std::complex<float>> *a,
2675 int lda) {
2676 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2677 PARAM(incy), PARAM(a), PARAM(lda));
2678
2679 ThenBlasImpl<uint64, uint64, std::complex<float>,
2680 const DeviceMemory<std::complex<float>> &, int,
2681 const DeviceMemory<std::complex<float>> &, int,
2682 DeviceMemory<std::complex<float>> *, int> impl;
2683 return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2684 incy, a, lda);
2685 }
2686
ThenBlasGerc(uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2687 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
2688 const DeviceMemory<std::complex<double>> &x,
2689 int incx,
2690 const DeviceMemory<std::complex<double>> &y,
2691 int incy, DeviceMemory<std::complex<double>> *a,
2692 int lda) {
2693 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2694 PARAM(incy), PARAM(a), PARAM(lda));
2695
2696 ThenBlasImpl<uint64, uint64, std::complex<double>,
2697 const DeviceMemory<std::complex<double>> &, int,
2698 const DeviceMemory<std::complex<double>> &, int,
2699 DeviceMemory<std::complex<double>> *, int> impl;
2700 return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2701 incy, a, lda);
2702 }
2703
ThenBlasGeru(uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2704 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
2705 const DeviceMemory<std::complex<float>> &x,
2706 int incx,
2707 const DeviceMemory<std::complex<float>> &y,
2708 int incy, DeviceMemory<std::complex<float>> *a,
2709 int lda) {
2710 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2711 PARAM(incy), PARAM(a), PARAM(lda));
2712
2713 ThenBlasImpl<uint64, uint64, std::complex<float>,
2714 const DeviceMemory<std::complex<float>> &, int,
2715 const DeviceMemory<std::complex<float>> &, int,
2716 DeviceMemory<std::complex<float>> *, int> impl;
2717 return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2718 incy, a, lda);
2719 }
2720
ThenBlasGeru(uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2721 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
2722 const DeviceMemory<std::complex<double>> &x,
2723 int incx,
2724 const DeviceMemory<std::complex<double>> &y,
2725 int incy, DeviceMemory<std::complex<double>> *a,
2726 int lda) {
2727 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2728 PARAM(incy), PARAM(a), PARAM(lda));
2729
2730 ThenBlasImpl<uint64, uint64, std::complex<double>,
2731 const DeviceMemory<std::complex<double>> &, int,
2732 const DeviceMemory<std::complex<double>> &, int,
2733 DeviceMemory<std::complex<double>> *, int> impl;
2734 return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2735 incy, a, lda);
2736 }
2737
ThenBlasHbmv(blas::UpperLower uplo,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2738 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2739 std::complex<float> alpha,
2740 const DeviceMemory<std::complex<float>> &a,
2741 int lda,
2742 const DeviceMemory<std::complex<float>> &x,
2743 int incx, std::complex<float> beta,
2744 DeviceMemory<std::complex<float>> *y, int incy) {
2745 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2746 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2747
2748 ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<float>,
2749 const DeviceMemory<std::complex<float>> &, int,
2750 const DeviceMemory<std::complex<float>> &, int,
2751 std::complex<float>, DeviceMemory<std::complex<float>> *,
2752 int> impl;
2753 return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2754 x, incx, beta, y, incy);
2755 }
2756
ThenBlasHbmv(blas::UpperLower uplo,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2757 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2758 std::complex<double> alpha,
2759 const DeviceMemory<std::complex<double>> &a,
2760 int lda,
2761 const DeviceMemory<std::complex<double>> &x,
2762 int incx, std::complex<double> beta,
2763 DeviceMemory<std::complex<double>> *y, int incy) {
2764 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2765 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2766
2767 ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<double>,
2768 const DeviceMemory<std::complex<double>> &, int,
2769 const DeviceMemory<std::complex<double>> &, int,
2770 std::complex<double>, DeviceMemory<std::complex<double>> *,
2771 int> impl;
2772 return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2773 x, incx, beta, y, incy);
2774 }
2775
ThenBlasHemv(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2776 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2777 std::complex<float> alpha,
2778 const DeviceMemory<std::complex<float>> &a,
2779 int lda,
2780 const DeviceMemory<std::complex<float>> &x,
2781 int incx, std::complex<float> beta,
2782 DeviceMemory<std::complex<float>> *y, int incy) {
2783 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2784 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2785
2786 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2787 const DeviceMemory<std::complex<float>> &, int,
2788 const DeviceMemory<std::complex<float>> &, int,
2789 std::complex<float>, DeviceMemory<std::complex<float>> *,
2790 int> impl;
2791 return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2792 incx, beta, y, incy);
2793 }
2794
ThenBlasHemv(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2795 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2796 std::complex<double> alpha,
2797 const DeviceMemory<std::complex<double>> &a,
2798 int lda,
2799 const DeviceMemory<std::complex<double>> &x,
2800 int incx, std::complex<double> beta,
2801 DeviceMemory<std::complex<double>> *y, int incy) {
2802 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2803 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2804
2805 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2806 const DeviceMemory<std::complex<double>> &, int,
2807 const DeviceMemory<std::complex<double>> &, int,
2808 std::complex<double>, DeviceMemory<std::complex<double>> *,
2809 int> impl;
2810 return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2811 incx, beta, y, incy);
2812 }
2813
ThenBlasHer(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * a,int lda)2814 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
2815 const DeviceMemory<std::complex<float>> &x,
2816 int incx, DeviceMemory<std::complex<float>> *a,
2817 int lda) {
2818 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2819 PARAM(a), PARAM(lda));
2820
2821 ThenBlasImpl<blas::UpperLower, uint64, float,
2822 const DeviceMemory<std::complex<float>> &, int,
2823 DeviceMemory<std::complex<float>> *, int> impl;
2824 return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2825 lda);
2826 }
2827
ThenBlasHer(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * a,int lda)2828 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
2829 const DeviceMemory<std::complex<double>> &x,
2830 int incx, DeviceMemory<std::complex<double>> *a,
2831 int lda) {
2832 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2833 PARAM(a), PARAM(lda));
2834
2835 ThenBlasImpl<blas::UpperLower, uint64, double,
2836 const DeviceMemory<std::complex<double>> &, int,
2837 DeviceMemory<std::complex<double>> *, int> impl;
2838 return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2839 lda);
2840 }
2841
ThenBlasHer2(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2842 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
2843 std::complex<float> alpha,
2844 const DeviceMemory<std::complex<float>> &x,
2845 int incx,
2846 const DeviceMemory<std::complex<float>> &y,
2847 int incy, DeviceMemory<std::complex<float>> *a,
2848 int lda) {
2849 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2850 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2851
2852 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2853 const DeviceMemory<std::complex<float>> &, int,
2854 const DeviceMemory<std::complex<float>> &, int,
2855 DeviceMemory<std::complex<float>> *, int> impl;
2856 return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
2857 incy, a, lda);
2858 }
2859
ThenBlasHer2(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2860 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
2861 std::complex<double> alpha,
2862 const DeviceMemory<std::complex<double>> &x,
2863 int incx,
2864 const DeviceMemory<std::complex<double>> &y,
2865 int incy, DeviceMemory<std::complex<double>> *a,
2866 int lda) {
2867 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2868 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2869
2870 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2871 const DeviceMemory<std::complex<double>> &, int,
2872 const DeviceMemory<std::complex<double>> &, int,
2873 DeviceMemory<std::complex<double>> *, int> impl;
2874 return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
2875 incy, a, lda);
2876 }
2877
ThenBlasHpmv(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & ap,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2878 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
2879 std::complex<float> alpha,
2880 const DeviceMemory<std::complex<float>> &ap,
2881 const DeviceMemory<std::complex<float>> &x,
2882 int incx, std::complex<float> beta,
2883 DeviceMemory<std::complex<float>> *y, int incy) {
2884 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2885 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2886
2887 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2888 const DeviceMemory<std::complex<float>> &,
2889 const DeviceMemory<std::complex<float>> &, int,
2890 std::complex<float>, DeviceMemory<std::complex<float>> *,
2891 int> impl;
2892 return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
2893 beta, y, incy);
2894 }
2895
ThenBlasHpmv(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & ap,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2896 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
2897 std::complex<double> alpha,
2898 const DeviceMemory<std::complex<double>> &ap,
2899 const DeviceMemory<std::complex<double>> &x,
2900 int incx, std::complex<double> beta,
2901 DeviceMemory<std::complex<double>> *y, int incy) {
2902 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2903 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2904
2905 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2906 const DeviceMemory<std::complex<double>> &,
2907 const DeviceMemory<std::complex<double>> &, int,
2908 std::complex<double>, DeviceMemory<std::complex<double>> *,
2909 int> impl;
2910 return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
2911 beta, y, incy);
2912 }
2913
ThenBlasHpr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * ap)2914 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
2915 const DeviceMemory<std::complex<float>> &x,
2916 int incx, DeviceMemory<std::complex<float>> *ap) {
2917 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2918 PARAM(ap));
2919
2920 ThenBlasImpl<blas::UpperLower, uint64, float,
2921 const DeviceMemory<std::complex<float>> &, int,
2922 DeviceMemory<std::complex<float>> *> impl;
2923 return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
2924 }
2925
ThenBlasHpr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * ap)2926 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
2927 const DeviceMemory<std::complex<double>> &x,
2928 int incx, DeviceMemory<std::complex<double>> *ap) {
2929 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2930 PARAM(ap));
2931
2932 ThenBlasImpl<blas::UpperLower, uint64, double,
2933 const DeviceMemory<std::complex<double>> &, int,
2934 DeviceMemory<std::complex<double>> *> impl;
2935 return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
2936 }
2937
ThenBlasHpr2(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * ap)2938 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
2939 std::complex<float> alpha,
2940 const DeviceMemory<std::complex<float>> &x,
2941 int incx,
2942 const DeviceMemory<std::complex<float>> &y,
2943 int incy, DeviceMemory<std::complex<float>> *ap) {
2944 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2945 PARAM(y), PARAM(incy), PARAM(ap));
2946
2947 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2948 const DeviceMemory<std::complex<float>> &, int,
2949 const DeviceMemory<std::complex<float>> &, int,
2950 DeviceMemory<std::complex<float>> *> impl;
2951 return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
2952 incy, ap);
2953 }
2954
ThenBlasHpr2(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * ap)2955 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
2956 std::complex<double> alpha,
2957 const DeviceMemory<std::complex<double>> &x,
2958 int incx,
2959 const DeviceMemory<std::complex<double>> &y,
2960 int incy, DeviceMemory<std::complex<double>> *ap) {
2961 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2962 PARAM(y), PARAM(incy), PARAM(ap));
2963
2964 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2965 const DeviceMemory<std::complex<double>> &, int,
2966 const DeviceMemory<std::complex<double>> &, int,
2967 DeviceMemory<std::complex<double>> *> impl;
2968 return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
2969 incy, ap);
2970 }
2971
ThenBlasSbmv(blas::UpperLower uplo,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2972 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2973 float alpha, const DeviceMemory<float> &a, int lda,
2974 const DeviceMemory<float> &x, int incx, float beta,
2975 DeviceMemory<float> *y, int incy) {
2976 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2977 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2978
2979 ThenBlasImpl<blas::UpperLower, uint64, uint64, float,
2980 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2981 int, float, DeviceMemory<float> *, int> impl;
2982 return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
2983 x, incx, beta, y, incy);
2984 }
2985
ThenBlasSbmv(blas::UpperLower uplo,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2986 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2987 double alpha, const DeviceMemory<double> &a,
2988 int lda, const DeviceMemory<double> &x, int incx,
2989 double beta, DeviceMemory<double> *y, int incy) {
2990 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2991 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2992
2993 ThenBlasImpl<blas::UpperLower, uint64, uint64, double,
2994 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2995 int, double, DeviceMemory<double> *, int> impl;
2996 return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
2997 x, incx, beta, y, incy);
2998 }
2999
ThenBlasSpmv(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & ap,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)3000 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
3001 const DeviceMemory<float> &ap,
3002 const DeviceMemory<float> &x, int incx, float beta,
3003 DeviceMemory<float> *y, int incy) {
3004 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
3005 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3006
3007 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3008 const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
3009 int> impl;
3010 return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
3011 beta, y, incy);
3012 }
3013
ThenBlasSpmv(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & ap,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)3014 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
3015 const DeviceMemory<double> &ap,
3016 const DeviceMemory<double> &x, int incx,
3017 double beta, DeviceMemory<double> *y, int incy) {
3018 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
3019 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3020
3021 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3022 const DeviceMemory<double> &, int, double,
3023 DeviceMemory<double> *, int> impl;
3024 return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
3025 beta, y, incy);
3026 }
3027
ThenBlasSpr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * ap)3028 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
3029 const DeviceMemory<float> &x, int incx,
3030 DeviceMemory<float> *ap) {
3031 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3032 PARAM(ap));
3033
3034 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3035 int, DeviceMemory<float> *> impl;
3036 return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
3037 }
3038
ThenBlasSpr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * ap)3039 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
3040 const DeviceMemory<double> &x, int incx,
3041 DeviceMemory<double> *ap) {
3042 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3043 PARAM(ap));
3044
3045 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3046 int, DeviceMemory<double> *> impl;
3047 return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
3048 }
3049
ThenBlasSpr2(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * ap)3050 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
3051 const DeviceMemory<float> &x, int incx,
3052 const DeviceMemory<float> &y, int incy,
3053 DeviceMemory<float> *ap) {
3054 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3055 PARAM(y), PARAM(incy), PARAM(ap));
3056
3057 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3058 int, const DeviceMemory<float> &, int,
3059 DeviceMemory<float> *> impl;
3060 return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
3061 incy, ap);
3062 }
3063
ThenBlasSpr2(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * ap)3064 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
3065 const DeviceMemory<double> &x, int incx,
3066 const DeviceMemory<double> &y, int incy,
3067 DeviceMemory<double> *ap) {
3068 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3069 PARAM(y), PARAM(incy), PARAM(ap));
3070
3071 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3072 int, const DeviceMemory<double> &, int,
3073 DeviceMemory<double> *> impl;
3074 return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
3075 incy, ap);
3076 }
3077
ThenBlasSymv(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)3078 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
3079 const DeviceMemory<float> &a, int lda,
3080 const DeviceMemory<float> &x, int incx, float beta,
3081 DeviceMemory<float> *y, int incy) {
3082 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
3083 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3084
3085 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3086 int, const DeviceMemory<float> &, int, float,
3087 DeviceMemory<float> *, int> impl;
3088 return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
3089 incx, beta, y, incy);
3090 }
3091
ThenBlasSymv(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)3092 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
3093 const DeviceMemory<double> &a, int lda,
3094 const DeviceMemory<double> &x, int incx,
3095 double beta, DeviceMemory<double> *y, int incy) {
3096 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
3097 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3098
3099 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3100 int, const DeviceMemory<double> &, int, double,
3101 DeviceMemory<double> *, int> impl;
3102 return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
3103 incx, beta, y, incy);
3104 }
3105
ThenBlasSyr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * a,int lda)3106 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
3107 const DeviceMemory<float> &x, int incx,
3108 DeviceMemory<float> *a, int lda) {
3109 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3110 PARAM(a), PARAM(lda));
3111
3112 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3113 int, DeviceMemory<float> *, int> impl;
3114 return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
3115 lda);
3116 }
3117
ThenBlasSyr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * a,int lda)3118 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
3119 const DeviceMemory<double> &x, int incx,
3120 DeviceMemory<double> *a, int lda) {
3121 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3122 PARAM(a), PARAM(lda));
3123
3124 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3125 int, DeviceMemory<double> *, int> impl;
3126 return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
3127 lda);
3128 }
3129
ThenBlasSyr2(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)3130 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
3131 const DeviceMemory<float> &x, int incx,
3132 const DeviceMemory<float> &y, int incy,
3133 DeviceMemory<float> *a, int lda) {
3134 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3135 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
3136
3137 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3138 int, const DeviceMemory<float> &, int, DeviceMemory<float> *,
3139 int> impl;
3140 return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
3141 incy, a, lda);
3142 }
3143
ThenBlasSyr2(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)3144 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
3145 const DeviceMemory<double> &x, int incx,
3146 const DeviceMemory<double> &y, int incy,
3147 DeviceMemory<double> *a, int lda) {
3148 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3149 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
3150
3151 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3152 int, const DeviceMemory<double> &, int, DeviceMemory<double> *,
3153 int> impl;
3154 return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
3155 incy, a, lda);
3156 }
3157
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3158 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3159 blas::Diagonal diag, uint64 n, uint64 k,
3160 const DeviceMemory<float> &a, int lda,
3161 DeviceMemory<float> *x, int incx) {
3162 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3163 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3164
3165 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3166 uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
3167 int> impl;
3168 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3169 lda, x, incx);
3170 }
3171
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3172 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3173 blas::Diagonal diag, uint64 n, uint64 k,
3174 const DeviceMemory<double> &a, int lda,
3175 DeviceMemory<double> *x, int incx) {
3176 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3177 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3178
3179 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3180 uint64, const DeviceMemory<double> &, int,
3181 DeviceMemory<double> *, int> impl;
3182 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3183 lda, x, incx);
3184 }
3185
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3186 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3187 blas::Diagonal diag, uint64 n, uint64 k,
3188 const DeviceMemory<std::complex<float>> &a,
3189 int lda, DeviceMemory<std::complex<float>> *x,
3190 int incx) {
3191 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3192 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3193
3194 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3195 uint64, const DeviceMemory<std::complex<float>> &, int,
3196 DeviceMemory<std::complex<float>> *, int> impl;
3197 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3198 lda, x, incx);
3199 }
3200
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3201 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3202 blas::Diagonal diag, uint64 n, uint64 k,
3203 const DeviceMemory<std::complex<double>> &a,
3204 int lda, DeviceMemory<std::complex<double>> *x,
3205 int incx) {
3206 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3207 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3208
3209 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3210 uint64, const DeviceMemory<std::complex<double>> &, int,
3211 DeviceMemory<std::complex<double>> *, int> impl;
3212 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3213 lda, x, incx);
3214 }
3215
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3216 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3217 blas::Diagonal diag, uint64 n, uint64 k,
3218 const DeviceMemory<float> &a, int lda,
3219 DeviceMemory<float> *x, int incx) {
3220 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3221 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3222
3223 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3224 uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
3225 int> impl;
3226 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3227 lda, x, incx);
3228 }
3229
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3230 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3231 blas::Diagonal diag, uint64 n, uint64 k,
3232 const DeviceMemory<double> &a, int lda,
3233 DeviceMemory<double> *x, int incx) {
3234 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3235 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3236
3237 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3238 uint64, const DeviceMemory<double> &, int,
3239 DeviceMemory<double> *, int> impl;
3240 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3241 lda, x, incx);
3242 }
3243
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3244 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3245 blas::Diagonal diag, uint64 n, uint64 k,
3246 const DeviceMemory<std::complex<float>> &a,
3247 int lda, DeviceMemory<std::complex<float>> *x,
3248 int incx) {
3249 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3250 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3251
3252 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3253 uint64, const DeviceMemory<std::complex<float>> &, int,
3254 DeviceMemory<std::complex<float>> *, int> impl;
3255 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3256 lda, x, incx);
3257 }
3258
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3259 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3260 blas::Diagonal diag, uint64 n, uint64 k,
3261 const DeviceMemory<std::complex<double>> &a,
3262 int lda, DeviceMemory<std::complex<double>> *x,
3263 int incx) {
3264 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3265 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3266
3267 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3268 uint64, const DeviceMemory<std::complex<double>> &, int,
3269 DeviceMemory<std::complex<double>> *, int> impl;
3270 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3271 lda, x, incx);
3272 }
3273
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)3274 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3275 blas::Diagonal diag, uint64 n,
3276 const DeviceMemory<float> &ap,
3277 DeviceMemory<float> *x, int incx) {
3278 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3279 PARAM(x), PARAM(incx));
3280
3281 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3282 const DeviceMemory<float> &, DeviceMemory<float> *, int> impl;
3283 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3284 incx);
3285 }
3286
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)3287 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3288 blas::Diagonal diag, uint64 n,
3289 const DeviceMemory<double> &ap,
3290 DeviceMemory<double> *x, int incx) {
3291 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3292 PARAM(x), PARAM(incx));
3293
3294 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3295 const DeviceMemory<double> &, DeviceMemory<double> *, int> impl;
3296 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3297 incx);
3298 }
3299
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)3300 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3301 blas::Diagonal diag, uint64 n,
3302 const DeviceMemory<std::complex<float>> &ap,
3303 DeviceMemory<std::complex<float>> *x, int incx) {
3304 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3305 PARAM(x), PARAM(incx));
3306
3307 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3308 const DeviceMemory<std::complex<float>> &,
3309 DeviceMemory<std::complex<float>> *, int> impl;
3310 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3311 incx);
3312 }
3313
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)3314 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3315 blas::Diagonal diag, uint64 n,
3316 const DeviceMemory<std::complex<double>> &ap,
3317 DeviceMemory<std::complex<double>> *x, int incx) {
3318 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3319 PARAM(x), PARAM(incx));
3320
3321 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3322 const DeviceMemory<std::complex<double>> &,
3323 DeviceMemory<std::complex<double>> *, int> impl;
3324 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3325 incx);
3326 }
3327
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)3328 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3329 blas::Diagonal diag, uint64 n,
3330 const DeviceMemory<float> &ap,
3331 DeviceMemory<float> *x, int incx) {
3332 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3333 PARAM(x), PARAM(incx));
3334
3335 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3336 const DeviceMemory<float> &, DeviceMemory<float> *, int> impl;
3337 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3338 incx);
3339 }
3340
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)3341 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3342 blas::Diagonal diag, uint64 n,
3343 const DeviceMemory<double> &ap,
3344 DeviceMemory<double> *x, int incx) {
3345 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3346 PARAM(x), PARAM(incx));
3347
3348 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3349 const DeviceMemory<double> &, DeviceMemory<double> *, int> impl;
3350 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3351 incx);
3352 }
3353
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)3354 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3355 blas::Diagonal diag, uint64 n,
3356 const DeviceMemory<std::complex<float>> &ap,
3357 DeviceMemory<std::complex<float>> *x, int incx) {
3358 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3359 PARAM(x), PARAM(incx));
3360
3361 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3362 const DeviceMemory<std::complex<float>> &,
3363 DeviceMemory<std::complex<float>> *, int> impl;
3364 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3365 incx);
3366 }
3367
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)3368 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3369 blas::Diagonal diag, uint64 n,
3370 const DeviceMemory<std::complex<double>> &ap,
3371 DeviceMemory<std::complex<double>> *x, int incx) {
3372 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3373 PARAM(x), PARAM(incx));
3374
3375 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3376 const DeviceMemory<std::complex<double>> &,
3377 DeviceMemory<std::complex<double>> *, int> impl;
3378 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3379 incx);
3380 }
3381
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3382 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3383 blas::Diagonal diag, uint64 n,
3384 const DeviceMemory<float> &a, int lda,
3385 DeviceMemory<float> *x, int incx) {
3386 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3387 PARAM(lda), PARAM(x), PARAM(incx));
3388
3389 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3390 const DeviceMemory<float> &, int, DeviceMemory<float> *,
3391 int> impl;
3392 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3393 lda, x, incx);
3394 }
3395
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3396 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3397 blas::Diagonal diag, uint64 n,
3398 const DeviceMemory<double> &a, int lda,
3399 DeviceMemory<double> *x, int incx) {
3400 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3401 PARAM(lda), PARAM(x), PARAM(incx));
3402
3403 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3404 const DeviceMemory<double> &, int, DeviceMemory<double> *,
3405 int> impl;
3406 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3407 lda, x, incx);
3408 }
3409
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3410 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3411 blas::Diagonal diag, uint64 n,
3412 const DeviceMemory<std::complex<float>> &a,
3413 int lda, DeviceMemory<std::complex<float>> *x,
3414 int incx) {
3415 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3416 PARAM(lda), PARAM(x), PARAM(incx));
3417
3418 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3419 const DeviceMemory<std::complex<float>> &, int,
3420 DeviceMemory<std::complex<float>> *, int> impl;
3421 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3422 lda, x, incx);
3423 }
3424
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3425 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3426 blas::Diagonal diag, uint64 n,
3427 const DeviceMemory<std::complex<double>> &a,
3428 int lda, DeviceMemory<std::complex<double>> *x,
3429 int incx) {
3430 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3431 PARAM(lda), PARAM(x), PARAM(incx));
3432
3433 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3434 const DeviceMemory<std::complex<double>> &, int,
3435 DeviceMemory<std::complex<double>> *, int> impl;
3436 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3437 lda, x, incx);
3438 }
3439
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3440 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3441 blas::Diagonal diag, uint64 n,
3442 const DeviceMemory<float> &a, int lda,
3443 DeviceMemory<float> *x, int incx) {
3444 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3445 PARAM(lda), PARAM(x), PARAM(incx));
3446
3447 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3448 const DeviceMemory<float> &, int, DeviceMemory<float> *,
3449 int> impl;
3450 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3451 lda, x, incx);
3452 }
3453
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3454 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3455 blas::Diagonal diag, uint64 n,
3456 const DeviceMemory<double> &a, int lda,
3457 DeviceMemory<double> *x, int incx) {
3458 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3459 PARAM(lda), PARAM(x), PARAM(incx));
3460
3461 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3462 const DeviceMemory<double> &, int, DeviceMemory<double> *,
3463 int> impl;
3464 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3465 lda, x, incx);
3466 }
3467
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3468 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3469 blas::Diagonal diag, uint64 n,
3470 const DeviceMemory<std::complex<float>> &a,
3471 int lda, DeviceMemory<std::complex<float>> *x,
3472 int incx) {
3473 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3474 PARAM(lda), PARAM(x), PARAM(incx));
3475
3476 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3477 const DeviceMemory<std::complex<float>> &, int,
3478 DeviceMemory<std::complex<float>> *, int> impl;
3479 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3480 lda, x, incx);
3481 }
3482
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3483 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3484 blas::Diagonal diag, uint64 n,
3485 const DeviceMemory<std::complex<double>> &a,
3486 int lda, DeviceMemory<std::complex<double>> *x,
3487 int incx) {
3488 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3489 PARAM(lda), PARAM(x), PARAM(incx));
3490
3491 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3492 const DeviceMemory<std::complex<double>> &, int,
3493 DeviceMemory<std::complex<double>> *, int> impl;
3494 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3495 lda, x, incx);
3496 }
3497
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc)3498 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3499 uint64 m, uint64 n, uint64 k, float alpha,
3500 const DeviceMemory<Eigen::half> &a, int lda,
3501 const DeviceMemory<Eigen::half> &b, int ldb,
3502 float beta,
3503 DeviceMemory<Eigen::half> *c, int ldc) {
3504 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3505 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3506 PARAM(beta), PARAM(c), PARAM(ldc));
3507
3508 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3509 const DeviceMemory<Eigen::half> &, int,
3510 const DeviceMemory<Eigen::half> &, int,
3511 float, DeviceMemory<Eigen::half> *, int> impl;
3512 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3513 alpha, a, lda, b, ldb, beta, c, ldc);
3514 }
3515
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)3516 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3517 uint64 m, uint64 n, uint64 k, float alpha,
3518 const DeviceMemory<float> &a, int lda,
3519 const DeviceMemory<float> &b, int ldb, float beta,
3520 DeviceMemory<float> *c, int ldc) {
3521 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3522 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3523 PARAM(beta), PARAM(c), PARAM(ldc));
3524
3525 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3526 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
3527 int, float, DeviceMemory<float> *, int> impl;
3528 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3529 alpha, a, lda, b, ldb, beta, c, ldc);
3530 }
3531
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)3532 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3533 uint64 m, uint64 n, uint64 k, double alpha,
3534 const DeviceMemory<double> &a, int lda,
3535 const DeviceMemory<double> &b, int ldb,
3536 double beta, DeviceMemory<double> *c, int ldc) {
3537 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3538 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3539 PARAM(beta), PARAM(c), PARAM(ldc));
3540
3541 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
3542 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
3543 int, double, DeviceMemory<double> *, int> impl;
3544 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3545 alpha, a, lda, b, ldb, beta, c, ldc);
3546 }
3547
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3548 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3549 uint64 m, uint64 n, uint64 k,
3550 std::complex<float> alpha,
3551 const DeviceMemory<std::complex<float>> &a,
3552 int lda,
3553 const DeviceMemory<std::complex<float>> &b,
3554 int ldb, std::complex<float> beta,
3555 DeviceMemory<std::complex<float>> *c, int ldc) {
3556 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3557 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3558 PARAM(beta), PARAM(c), PARAM(ldc));
3559
3560 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3561 std::complex<float>, const DeviceMemory<std::complex<float>> &,
3562 int, const DeviceMemory<std::complex<float>> &, int,
3563 std::complex<float>, DeviceMemory<std::complex<float>> *,
3564 int> impl;
3565 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3566 alpha, a, lda, b, ldb, beta, c, ldc);
3567 }
3568
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3569 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3570 uint64 m, uint64 n, uint64 k,
3571 std::complex<double> alpha,
3572 const DeviceMemory<std::complex<double>> &a,
3573 int lda,
3574 const DeviceMemory<std::complex<double>> &b,
3575 int ldb, std::complex<double> beta,
3576 DeviceMemory<std::complex<double>> *c, int ldc) {
3577 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3578 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3579 PARAM(beta), PARAM(c), PARAM(ldc));
3580
3581 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3582 std::complex<double>, const DeviceMemory<std::complex<double>> &,
3583 int, const DeviceMemory<std::complex<double>> &, int,
3584 std::complex<double>, DeviceMemory<std::complex<double>> *,
3585 int> impl;
3586 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3587 alpha, a, lda, b, ldb, beta, c, ldc);
3588 }
3589
3590 namespace {
3591 // Like ThenBlasImpl, except this expects the last argument of blas_func to be a
3592 // blas::ProfileResult*. This functor doesn't put the stream into an error
3593 // state if the op fails and the profile result is non-null. Instead, the
3594 // error-ness is returned in the profile result itself.
3595 template <typename... Args>
3596 struct ThenBlasWithProfileImpl {
operator ()stream_executor::__anon3e261ebe0211::ThenBlasWithProfileImpl3597 Stream &operator()(Stream *stream,
3598 bool (blas::BlasSupport::*blas_func)(
3599 Stream *, Args..., blas::ProfileResult *),
3600 Args... args, blas::ProfileResult *profile_result) {
3601 ThenBlasImpl<Args..., blas::ProfileResult *> Runner;
3602 bool record_error = profile_result == nullptr;
3603 return Runner.Run(stream, blas_func, record_error, args..., profile_result);
3604 }
3605 };
3606 } // anonymous namespace
3607
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy,blas::ProfileResult * output_profile_result)3608 Stream &Stream::ThenBlasGemvWithProfiling(
3609 blas::Transpose trans, uint64 m, uint64 n, float alpha,
3610 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
3611 int incx, float beta, DeviceMemory<float> *y, int incy,
3612 blas::ProfileResult *output_profile_result) {
3613 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3614 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3615 PARAM(incy));
3616
3617 ThenBlasWithProfileImpl<
3618 blas::Transpose, uint64, uint64, float, const DeviceMemory<float> &, int,
3619 const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int>
3620 impl;
3621 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3622 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3623 }
3624
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy,blas::ProfileResult * output_profile_result)3625 Stream &Stream::ThenBlasGemvWithProfiling(
3626 blas::Transpose trans, uint64 m, uint64 n, double alpha,
3627 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
3628 int incx, double beta, DeviceMemory<double> *y, int incy,
3629 blas::ProfileResult *output_profile_result) {
3630 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3631 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3632 PARAM(incy));
3633
3634 ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, double,
3635 const DeviceMemory<double> &, int,
3636 const DeviceMemory<double> &, int, double,
3637 DeviceMemory<double> *, int>
3638 impl;
3639 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3640 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3641 }
3642
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy,blas::ProfileResult * output_profile_result)3643 Stream &Stream::ThenBlasGemvWithProfiling(
3644 blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
3645 const DeviceMemory<std::complex<float>> &a, int lda,
3646 const DeviceMemory<std::complex<float>> &x, int incx,
3647 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
3648 blas::ProfileResult *output_profile_result) {
3649 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3650 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3651 PARAM(incy));
3652
3653 ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<float>,
3654 const DeviceMemory<std::complex<float>> &, int,
3655 const DeviceMemory<std::complex<float>> &, int,
3656 std::complex<float>,
3657 DeviceMemory<std::complex<float>> *, int>
3658 impl;
3659 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3660 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3661 }
3662
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy,blas::ProfileResult * output_profile_result)3663 Stream &Stream::ThenBlasGemvWithProfiling(
3664 blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
3665 const DeviceMemory<std::complex<double>> &a, int lda,
3666 const DeviceMemory<std::complex<double>> &x, int incx,
3667 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
3668 blas::ProfileResult *output_profile_result) {
3669 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3670 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3671 PARAM(incy));
3672
3673 ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<double>,
3674 const DeviceMemory<std::complex<double>> &, int,
3675 const DeviceMemory<std::complex<double>> &, int,
3676 std::complex<double>,
3677 DeviceMemory<std::complex<double>> *, int>
3678 impl;
3679 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3680 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3681 }
3682
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ProfileResult * output_profile_result)3683 Stream &Stream::ThenBlasGemmWithProfiling(
3684 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3685 uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
3686 const DeviceMemory<Eigen::half> &b, int ldb, float beta,
3687 DeviceMemory<Eigen::half> *c, int ldc,
3688 blas::ProfileResult *output_profile_result) {
3689 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3690 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3691 PARAM(beta), PARAM(c), PARAM(ldc));
3692
3693 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3694 uint64, float, const DeviceMemory<Eigen::half> &, int,
3695 const DeviceMemory<Eigen::half> &, int, float,
3696 DeviceMemory<Eigen::half> *, int>
3697 impl;
3698 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3699 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3700 output_profile_result);
3701 }
3702
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ProfileResult * output_profile_result)3703 Stream &Stream::ThenBlasGemmWithProfiling(
3704 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3705 uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
3706 const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
3707 int ldc, blas::ProfileResult *output_profile_result) {
3708 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3709 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3710 PARAM(beta), PARAM(c), PARAM(ldc));
3711
3712 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3713 uint64, float, const DeviceMemory<float> &, int,
3714 const DeviceMemory<float> &, int, float,
3715 DeviceMemory<float> *, int>
3716 impl;
3717 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3718 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3719 output_profile_result);
3720 }
3721
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ProfileResult * output_profile_result)3722 Stream &Stream::ThenBlasGemmWithProfiling(
3723 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3724 uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
3725 const DeviceMemory<double> &b, int ldb, double beta,
3726 DeviceMemory<double> *c, int ldc,
3727 blas::ProfileResult *output_profile_result) {
3728 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3729 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3730 PARAM(beta), PARAM(c), PARAM(ldc));
3731
3732 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3733 uint64, double, const DeviceMemory<double> &, int,
3734 const DeviceMemory<double> &, int, double,
3735 DeviceMemory<double> *, int>
3736 impl;
3737 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3738 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3739 output_profile_result);
3740 }
3741
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ProfileResult * output_profile_result)3742 Stream &Stream::ThenBlasGemmWithProfiling(
3743 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3744 uint64 k, std::complex<float> alpha,
3745 const DeviceMemory<std::complex<float>> &a, int lda,
3746 const DeviceMemory<std::complex<float>> &b, int ldb,
3747 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
3748 blas::ProfileResult *output_profile_result) {
3749 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3750 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3751 PARAM(beta), PARAM(c), PARAM(ldc));
3752
3753 ThenBlasWithProfileImpl<
3754 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3755 std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
3756 const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
3757 DeviceMemory<std::complex<float>> *, int>
3758 impl;
3759 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3760 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3761 output_profile_result);
3762 }
3763
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ProfileResult * output_profile_result)3764 Stream &Stream::ThenBlasGemmWithProfiling(
3765 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3766 uint64 k, std::complex<double> alpha,
3767 const DeviceMemory<std::complex<double>> &a, int lda,
3768 const DeviceMemory<std::complex<double>> &b, int ldb,
3769 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
3770 blas::ProfileResult *output_profile_result) {
3771 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3772 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3773 PARAM(beta), PARAM(c), PARAM(ldc));
3774
3775 ThenBlasWithProfileImpl<
3776 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3777 std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
3778 const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
3779 DeviceMemory<std::complex<double>> *, int>
3780 impl;
3781 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3782 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3783 output_profile_result);
3784 }
3785
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<Eigen::half> & alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,const HostOrDeviceScalar<Eigen::half> & beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3786 Stream &Stream::ThenBlasGemmWithAlgorithm(
3787 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3788 uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
3789 const DeviceMemory<Eigen::half> &a, int lda,
3790 const DeviceMemory<Eigen::half> &b, int ldb,
3791 const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
3792 int ldc, blas::ComputationType computation_type,
3793 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3794 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3795 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3796 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3797 PARAM(algorithm));
3798
3799 ThenBlasWithProfileImpl<
3800 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3801 const HostOrDeviceScalar<Eigen::half> &,
3802 const DeviceMemory<Eigen::half> &, int, const DeviceMemory<Eigen::half> &,
3803 int, const HostOrDeviceScalar<Eigen::half> &, DeviceMemory<Eigen::half> *,
3804 int, blas::ComputationType, blas::AlgorithmType>
3805 impl;
3806 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3807 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3808 algorithm, output_profile_result);
3809 }
3810
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<int> & alpha,const DeviceMemory<int8> & a,int lda,const DeviceMemory<int8> & b,int ldb,const HostOrDeviceScalar<int> & beta,DeviceMemory<int> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3811 Stream &Stream::ThenBlasGemmWithAlgorithm(
3812 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3813 uint64 k, const HostOrDeviceScalar<int> &alpha, const DeviceMemory<int8> &a,
3814 int lda, const DeviceMemory<int8> &b, int ldb,
3815 const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c, int ldc,
3816 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3817 blas::ProfileResult *output_profile_result) {
3818 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3819 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3820 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3821 PARAM(algorithm));
3822
3823 ThenBlasWithProfileImpl<
3824 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3825 const HostOrDeviceScalar<int> &, const DeviceMemory<int8> &, int,
3826 const DeviceMemory<int8> &, int, const HostOrDeviceScalar<int> &,
3827 DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
3828 impl;
3829 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3830 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3831 algorithm, output_profile_result);
3832 }
3833
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<float> & alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,const HostOrDeviceScalar<float> & beta,DeviceMemory<float> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3834 Stream &Stream::ThenBlasGemmWithAlgorithm(
3835 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3836 uint64 k, const HostOrDeviceScalar<float> &alpha,
3837 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
3838 int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
3839 int ldc, blas::ComputationType computation_type,
3840 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3841 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3842 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3843 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3844 PARAM(algorithm));
3845
3846 ThenBlasWithProfileImpl<
3847 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3848 const HostOrDeviceScalar<float> &, const DeviceMemory<float> &, int,
3849 const DeviceMemory<float> &, int, const HostOrDeviceScalar<float> &,
3850 DeviceMemory<float> *, int, blas::ComputationType, blas::AlgorithmType>
3851 impl;
3852 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3853 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3854 algorithm, output_profile_result);
3855 }
3856
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<double> & alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,const HostOrDeviceScalar<double> & beta,DeviceMemory<double> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3857 Stream &Stream::ThenBlasGemmWithAlgorithm(
3858 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3859 uint64 k, const HostOrDeviceScalar<double> &alpha,
3860 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
3861 int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
3862 int ldc, blas::ComputationType computation_type,
3863 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3864 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3865 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3866 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3867 PARAM(algorithm));
3868
3869 ThenBlasWithProfileImpl<
3870 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3871 const HostOrDeviceScalar<double> &, const DeviceMemory<double> &, int,
3872 const DeviceMemory<double> &, int, const HostOrDeviceScalar<double> &,
3873 DeviceMemory<double> *, int, blas::ComputationType, blas::AlgorithmType>
3874 impl;
3875 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3876 m, n, k, HostOrDeviceScalar<double>(alpha), a, lda, b, ldb,
3877 HostOrDeviceScalar<double>(beta), c, ldc, computation_type,
3878 algorithm, output_profile_result);
3879 }
3880
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<std::complex<float>> & alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,const HostOrDeviceScalar<std::complex<float>> & beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3881 Stream &Stream::ThenBlasGemmWithAlgorithm(
3882 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3883 uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
3884 const DeviceMemory<std::complex<float>> &a, int lda,
3885 const DeviceMemory<std::complex<float>> &b, int ldb,
3886 const HostOrDeviceScalar<std::complex<float>> &beta,
3887 DeviceMemory<std::complex<float>> *c, int ldc,
3888 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3889 blas::ProfileResult *output_profile_result) {
3890 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3891 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3892 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3893 PARAM(algorithm));
3894
3895 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3896 uint64,
3897 const HostOrDeviceScalar<std::complex<float>> &,
3898 const DeviceMemory<std::complex<float>> &, int,
3899 const DeviceMemory<std::complex<float>> &, int,
3900 const HostOrDeviceScalar<std::complex<float>> &,
3901 DeviceMemory<std::complex<float>> *, int,
3902 blas::ComputationType, blas::AlgorithmType>
3903 impl;
3904 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3905 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3906 algorithm, output_profile_result);
3907 }
3908
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<std::complex<double>> & alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,const HostOrDeviceScalar<std::complex<double>> & beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3909 Stream &Stream::ThenBlasGemmWithAlgorithm(
3910 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3911 uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
3912 const DeviceMemory<std::complex<double>> &a, int lda,
3913 const DeviceMemory<std::complex<double>> &b, int ldb,
3914 const HostOrDeviceScalar<std::complex<double>> &beta,
3915 DeviceMemory<std::complex<double>> *c, int ldc,
3916 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3917 blas::ProfileResult *output_profile_result) {
3918 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3919 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3920 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3921 PARAM(algorithm));
3922
3923 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3924 uint64,
3925 const HostOrDeviceScalar<std::complex<double>> &,
3926 const DeviceMemory<std::complex<double>> &, int,
3927 const DeviceMemory<std::complex<double>> &, int,
3928 const HostOrDeviceScalar<std::complex<double>> &,
3929 DeviceMemory<std::complex<double>> *, int,
3930 blas::ComputationType, blas::AlgorithmType>
3931 impl;
3932 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3933 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3934 algorithm, output_profile_result);
3935 }
3936
ThenBlasHemm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3937 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
3938 uint64 n, std::complex<float> alpha,
3939 const DeviceMemory<std::complex<float>> &a,
3940 int lda,
3941 const DeviceMemory<std::complex<float>> &b,
3942 int ldb, std::complex<float> beta,
3943 DeviceMemory<std::complex<float>> *c, int ldc) {
3944 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3945 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3946 PARAM(ldc));
3947
3948 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3949 std::complex<float>, const DeviceMemory<std::complex<float>> &,
3950 int, const DeviceMemory<std::complex<float>> &, int,
3951 std::complex<float>, DeviceMemory<std::complex<float>> *,
3952 int> impl;
3953 return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
3954 lda, b, ldb, beta, c, ldc);
3955 }
3956
ThenBlasHemm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3957 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
3958 uint64 n, std::complex<double> alpha,
3959 const DeviceMemory<std::complex<double>> &a,
3960 int lda,
3961 const DeviceMemory<std::complex<double>> &b,
3962 int ldb, std::complex<double> beta,
3963 DeviceMemory<std::complex<double>> *c, int ldc) {
3964 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3965 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3966 PARAM(ldc));
3967
3968 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3969 std::complex<double>, const DeviceMemory<std::complex<double>> &,
3970 int, const DeviceMemory<std::complex<double>> &, int,
3971 std::complex<double>, DeviceMemory<std::complex<double>> *,
3972 int> impl;
3973 return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
3974 lda, b, ldb, beta, c, ldc);
3975 }
3976
ThenBlasHerk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<std::complex<float>> & a,int lda,float beta,DeviceMemory<std::complex<float>> * c,int ldc)3977 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
3978 uint64 n, uint64 k, float alpha,
3979 const DeviceMemory<std::complex<float>> &a,
3980 int lda, float beta,
3981 DeviceMemory<std::complex<float>> *c, int ldc) {
3982 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3983 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3984
3985 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
3986 const DeviceMemory<std::complex<float>> &, int, float,
3987 DeviceMemory<std::complex<float>> *, int> impl;
3988 return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
3989 lda, beta, c, ldc);
3990 }
3991
ThenBlasHerk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<std::complex<double>> & a,int lda,double beta,DeviceMemory<std::complex<double>> * c,int ldc)3992 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
3993 uint64 n, uint64 k, double alpha,
3994 const DeviceMemory<std::complex<double>> &a,
3995 int lda, double beta,
3996 DeviceMemory<std::complex<double>> *c, int ldc) {
3997 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3998 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3999
4000 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
4001 const DeviceMemory<std::complex<double>> &, int, double,
4002 DeviceMemory<std::complex<double>> *, int> impl;
4003 return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
4004 lda, beta, c, ldc);
4005 }
4006
ThenBlasHer2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,float beta,DeviceMemory<std::complex<float>> * c,int ldc)4007 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
4008 uint64 n, uint64 k, std::complex<float> alpha,
4009 const DeviceMemory<std::complex<float>> &a,
4010 int lda,
4011 const DeviceMemory<std::complex<float>> &b,
4012 int ldb, float beta,
4013 DeviceMemory<std::complex<float>> *c, int ldc) {
4014 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4015 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4016 PARAM(ldc));
4017
4018 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4019 std::complex<float>, const DeviceMemory<std::complex<float>> &,
4020 int, const DeviceMemory<std::complex<float>> &, int, float,
4021 DeviceMemory<std::complex<float>> *, int> impl;
4022 return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
4023 a, lda, b, ldb, beta, c, ldc);
4024 }
4025
ThenBlasHer2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,double beta,DeviceMemory<std::complex<double>> * c,int ldc)4026 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
4027 uint64 n, uint64 k, std::complex<double> alpha,
4028 const DeviceMemory<std::complex<double>> &a,
4029 int lda,
4030 const DeviceMemory<std::complex<double>> &b,
4031 int ldb, double beta,
4032 DeviceMemory<std::complex<double>> *c, int ldc) {
4033 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4034 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4035 PARAM(ldc));
4036
4037 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4038 std::complex<double>, const DeviceMemory<std::complex<double>> &,
4039 int, const DeviceMemory<std::complex<double>> &, int, double,
4040 DeviceMemory<std::complex<double>> *, int> impl;
4041 return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
4042 a, lda, b, ldb, beta, c, ldc);
4043 }
4044
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)4045 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4046 uint64 n, float alpha,
4047 const DeviceMemory<float> &a, int lda,
4048 const DeviceMemory<float> &b, int ldb, float beta,
4049 DeviceMemory<float> *c, int ldc) {
4050 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4051 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4052 PARAM(ldc));
4053
4054 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, float,
4055 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
4056 int, float, DeviceMemory<float> *, int> impl;
4057 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4058 lda, b, ldb, beta, c, ldc);
4059 }
4060
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)4061 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4062 uint64 n, double alpha,
4063 const DeviceMemory<double> &a, int lda,
4064 const DeviceMemory<double> &b, int ldb,
4065 double beta, DeviceMemory<double> *c, int ldc) {
4066 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4067 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4068 PARAM(ldc));
4069
4070 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, double,
4071 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
4072 int, double, DeviceMemory<double> *, int> impl;
4073 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4074 lda, b, ldb, beta, c, ldc);
4075 }
4076
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)4077 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4078 uint64 n, std::complex<float> alpha,
4079 const DeviceMemory<std::complex<float>> &a,
4080 int lda,
4081 const DeviceMemory<std::complex<float>> &b,
4082 int ldb, std::complex<float> beta,
4083 DeviceMemory<std::complex<float>> *c, int ldc) {
4084 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4085 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4086 PARAM(ldc));
4087
4088 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
4089 std::complex<float>, const DeviceMemory<std::complex<float>> &,
4090 int, const DeviceMemory<std::complex<float>> &, int,
4091 std::complex<float>, DeviceMemory<std::complex<float>> *,
4092 int> impl;
4093 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4094 lda, b, ldb, beta, c, ldc);
4095 }
4096
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)4097 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4098 uint64 n, std::complex<double> alpha,
4099 const DeviceMemory<std::complex<double>> &a,
4100 int lda,
4101 const DeviceMemory<std::complex<double>> &b,
4102 int ldb, std::complex<double> beta,
4103 DeviceMemory<std::complex<double>> *c, int ldc) {
4104 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4105 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4106 PARAM(ldc));
4107
4108 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
4109 std::complex<double>, const DeviceMemory<std::complex<double>> &,
4110 int, const DeviceMemory<std::complex<double>> &, int,
4111 std::complex<double>, DeviceMemory<std::complex<double>> *,
4112 int> impl;
4113 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4114 lda, b, ldb, beta, c, ldc);
4115 }
4116
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,float beta,DeviceMemory<float> * c,int ldc)4117 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4118 uint64 n, uint64 k, float alpha,
4119 const DeviceMemory<float> &a, int lda, float beta,
4120 DeviceMemory<float> *c, int ldc) {
4121 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4122 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4123
4124 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
4125 const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
4126 int> impl;
4127 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4128 lda, beta, c, ldc);
4129 }
4130
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,double beta,DeviceMemory<double> * c,int ldc)4131 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4132 uint64 n, uint64 k, double alpha,
4133 const DeviceMemory<double> &a, int lda,
4134 double beta, DeviceMemory<double> *c, int ldc) {
4135 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4136 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4137
4138 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
4139 const DeviceMemory<double> &, int, double,
4140 DeviceMemory<double> *, int> impl;
4141 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4142 lda, beta, c, ldc);
4143 }
4144
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)4145 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4146 uint64 n, uint64 k, std::complex<float> alpha,
4147 const DeviceMemory<std::complex<float>> &a,
4148 int lda, std::complex<float> beta,
4149 DeviceMemory<std::complex<float>> *c, int ldc) {
4150 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4151 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4152
4153 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4154 std::complex<float>, const DeviceMemory<std::complex<float>> &,
4155 int, std::complex<float>, DeviceMemory<std::complex<float>> *,
4156 int> impl;
4157 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4158 lda, beta, c, ldc);
4159 }
4160
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)4161 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4162 uint64 n, uint64 k, std::complex<double> alpha,
4163 const DeviceMemory<std::complex<double>> &a,
4164 int lda, std::complex<double> beta,
4165 DeviceMemory<std::complex<double>> *c, int ldc) {
4166 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4167 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4168
4169 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4170 std::complex<double>, const DeviceMemory<std::complex<double>> &,
4171 int, std::complex<double>, DeviceMemory<std::complex<double>> *,
4172 int> impl;
4173 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4174 lda, beta, c, ldc);
4175 }
4176
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)4177 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4178 uint64 n, uint64 k, float alpha,
4179 const DeviceMemory<float> &a, int lda,
4180 const DeviceMemory<float> &b, int ldb, float beta,
4181 DeviceMemory<float> *c, int ldc) {
4182 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4183 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4184 PARAM(ldc));
4185
4186 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
4187 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
4188 int, float, DeviceMemory<float> *, int> impl;
4189 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4190 a, lda, b, ldb, beta, c, ldc);
4191 }
4192
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)4193 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4194 uint64 n, uint64 k, double alpha,
4195 const DeviceMemory<double> &a, int lda,
4196 const DeviceMemory<double> &b, int ldb,
4197 double beta, DeviceMemory<double> *c, int ldc) {
4198 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4199 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4200 PARAM(ldc));
4201
4202 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
4203 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
4204 int, double, DeviceMemory<double> *, int> impl;
4205 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4206 a, lda, b, ldb, beta, c, ldc);
4207 }
4208
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)4209 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4210 uint64 n, uint64 k, std::complex<float> alpha,
4211 const DeviceMemory<std::complex<float>> &a,
4212 int lda,
4213 const DeviceMemory<std::complex<float>> &b,
4214 int ldb, std::complex<float> beta,
4215 DeviceMemory<std::complex<float>> *c, int ldc) {
4216 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4217 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4218 PARAM(ldc));
4219
4220 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4221 std::complex<float>, const DeviceMemory<std::complex<float>> &,
4222 int, const DeviceMemory<std::complex<float>> &, int,
4223 std::complex<float>, DeviceMemory<std::complex<float>> *,
4224 int> impl;
4225 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4226 a, lda, b, ldb, beta, c, ldc);
4227 }
4228
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)4229 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4230 uint64 n, uint64 k, std::complex<double> alpha,
4231 const DeviceMemory<std::complex<double>> &a,
4232 int lda,
4233 const DeviceMemory<std::complex<double>> &b,
4234 int ldb, std::complex<double> beta,
4235 DeviceMemory<std::complex<double>> *c, int ldc) {
4236 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4237 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4238 PARAM(ldc));
4239
4240 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4241 std::complex<double>, const DeviceMemory<std::complex<double>> &,
4242 int, const DeviceMemory<std::complex<double>> &, int,
4243 std::complex<double>, DeviceMemory<std::complex<double>> *,
4244 int> impl;
4245 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4246 a, lda, b, ldb, beta, c, ldc);
4247 }
4248
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)4249 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4250 blas::Transpose transa, blas::Diagonal diag,
4251 uint64 m, uint64 n, float alpha,
4252 const DeviceMemory<float> &a, int lda,
4253 DeviceMemory<float> *b, int ldb) {
4254 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4255 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4256
4257 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4258 uint64, uint64, float, const DeviceMemory<float> &, int,
4259 DeviceMemory<float> *, int> impl;
4260 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4261 n, alpha, a, lda, b, ldb);
4262 }
4263
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)4264 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4265 blas::Transpose transa, blas::Diagonal diag,
4266 uint64 m, uint64 n, double alpha,
4267 const DeviceMemory<double> &a, int lda,
4268 DeviceMemory<double> *b, int ldb) {
4269 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4270 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4271
4272 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4273 uint64, uint64, double, const DeviceMemory<double> &, int,
4274 DeviceMemory<double> *, int> impl;
4275 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4276 n, alpha, a, lda, b, ldb);
4277 }
4278
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)4279 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4280 blas::Transpose transa, blas::Diagonal diag,
4281 uint64 m, uint64 n, std::complex<float> alpha,
4282 const DeviceMemory<std::complex<float>> &a,
4283 int lda, DeviceMemory<std::complex<float>> *b,
4284 int ldb) {
4285 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4286 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4287
4288 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4289 uint64, uint64, std::complex<float>,
4290 const DeviceMemory<std::complex<float>> &, int,
4291 DeviceMemory<std::complex<float>> *, int> impl;
4292 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4293 n, alpha, a, lda, b, ldb);
4294 }
4295
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)4296 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4297 blas::Transpose transa, blas::Diagonal diag,
4298 uint64 m, uint64 n, std::complex<double> alpha,
4299 const DeviceMemory<std::complex<double>> &a,
4300 int lda, DeviceMemory<std::complex<double>> *b,
4301 int ldb) {
4302 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4303 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4304
4305 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4306 uint64, uint64, std::complex<double>,
4307 const DeviceMemory<std::complex<double>> &, int,
4308 DeviceMemory<std::complex<double>> *, int> impl;
4309 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4310 n, alpha, a, lda, b, ldb);
4311 }
4312
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)4313 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4314 blas::Transpose transa, blas::Diagonal diag,
4315 uint64 m, uint64 n, float alpha,
4316 const DeviceMemory<float> &a, int lda,
4317 DeviceMemory<float> *b, int ldb) {
4318 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4319 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4320
4321 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4322 uint64, uint64, float, const DeviceMemory<float> &, int,
4323 DeviceMemory<float> *, int> impl;
4324 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4325 n, alpha, a, lda, b, ldb);
4326 }
4327
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)4328 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4329 blas::Transpose transa, blas::Diagonal diag,
4330 uint64 m, uint64 n, double alpha,
4331 const DeviceMemory<double> &a, int lda,
4332 DeviceMemory<double> *b, int ldb) {
4333 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4334 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4335
4336 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4337 uint64, uint64, double, const DeviceMemory<double> &, int,
4338 DeviceMemory<double> *, int> impl;
4339 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4340 n, alpha, a, lda, b, ldb);
4341 }
4342
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)4343 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4344 blas::Transpose transa, blas::Diagonal diag,
4345 uint64 m, uint64 n, std::complex<float> alpha,
4346 const DeviceMemory<std::complex<float>> &a,
4347 int lda, DeviceMemory<std::complex<float>> *b,
4348 int ldb) {
4349 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4350 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4351
4352 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4353 uint64, uint64, std::complex<float>,
4354 const DeviceMemory<std::complex<float>> &, int,
4355 DeviceMemory<std::complex<float>> *, int> impl;
4356 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4357 n, alpha, a, lda, b, ldb);
4358 }
4359
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)4360 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4361 blas::Transpose transa, blas::Diagonal diag,
4362 uint64 m, uint64 n, std::complex<double> alpha,
4363 const DeviceMemory<std::complex<double>> &a,
4364 int lda, DeviceMemory<std::complex<double>> *b,
4365 int ldb) {
4366 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4367 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4368
4369 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4370 uint64, uint64, std::complex<double>,
4371 const DeviceMemory<std::complex<double>> &, int,
4372 DeviceMemory<std::complex<double>> *, int> impl;
4373 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4374 n, alpha, a, lda, b, ldb);
4375 }
4376
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count)4377 Stream &Stream::ThenBlasGemmBatched(
4378 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4379 uint64 k, float alpha,
4380 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
4381 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
4382 const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
4383 int batch_count) {
4384 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4385 b, ldb, beta, c, ldc, batch_count,
4386 /*scratch_allocator=*/nullptr);
4387 }
4388
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4389 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4390 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4391 uint64 k, float alpha,
4392 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
4393 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
4394 const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
4395 int batch_count, ScratchAllocator *scratch_allocator) {
4396 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4397 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4398 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4399
4400 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4401 const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
4402 const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
4403 float, const port::ArraySlice<DeviceMemory<Eigen::half> *> &,
4404 int, int, ScratchAllocator *>
4405 impl;
4406 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4407 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4408 scratch_allocator);
4409 }
4410
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count)4411 Stream &Stream::ThenBlasGemmBatched(
4412 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4413 uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
4414 int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
4415 float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
4416 int batch_count) {
4417 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4418 b, ldb, beta, c, ldc, batch_count,
4419 /*scratch_allocator=*/nullptr);
4420 }
4421
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4422 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4423 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4424 uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
4425 int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
4426 float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
4427 int batch_count, ScratchAllocator *scratch_allocator) {
4428 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4429 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4430 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4431
4432 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4433 const port::ArraySlice<DeviceMemory<float> *> &, int,
4434 const port::ArraySlice<DeviceMemory<float> *> &, int, float,
4435 const port::ArraySlice<DeviceMemory<float> *> &, int, int,
4436 ScratchAllocator *>
4437 impl;
4438 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4439 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4440 scratch_allocator);
4441 }
4442
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count)4443 Stream &Stream::ThenBlasGemmBatched(
4444 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4445 uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
4446 int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
4447 double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
4448 int batch_count) {
4449 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4450 b, ldb, beta, c, ldc, batch_count,
4451 /*scratch_allocator=*/nullptr);
4452 }
4453
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4454 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4455 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4456 uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
4457 int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
4458 double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
4459 int batch_count, ScratchAllocator *scratch_allocator) {
4460 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4461 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4462 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4463
4464 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
4465 const port::ArraySlice<DeviceMemory<double> *> &, int,
4466 const port::ArraySlice<DeviceMemory<double> *> &, int, double,
4467 const port::ArraySlice<DeviceMemory<double> *> &, int, int,
4468 ScratchAllocator *>
4469 impl;
4470 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4471 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4472 scratch_allocator);
4473 }
4474
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count)4475 Stream &Stream::ThenBlasGemmBatched(
4476 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4477 uint64 k, std::complex<float> alpha,
4478 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
4479 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
4480 std::complex<float> beta,
4481 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
4482 int batch_count) {
4483 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4484 b, ldb, beta, c, ldc, batch_count,
4485 /*scratch_allocator=*/nullptr);
4486 }
4487
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4488 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4489 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4490 uint64 k, std::complex<float> alpha,
4491 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
4492 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
4493 std::complex<float> beta,
4494 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
4495 int batch_count, ScratchAllocator *scratch_allocator) {
4496 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4497 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4498 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4499
4500 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4501 std::complex<float>,
4502 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4503 int,
4504 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4505 int, std::complex<float>,
4506 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4507 int, int, ScratchAllocator *>
4508 impl;
4509 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4510 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4511 scratch_allocator);
4512 }
4513
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c,int ldc,int batch_count)4514 Stream &Stream::ThenBlasGemmBatched(
4515 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4516 uint64 k, std::complex<double> alpha,
4517 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
4518 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
4519 std::complex<double> beta,
4520 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
4521 int batch_count) {
4522 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4523 b, ldb, beta, c, ldc, batch_count,
4524 /*scratch_allocator=*/nullptr);
4525 }
4526
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4527 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4528 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4529 uint64 k, std::complex<double> alpha,
4530 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
4531 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
4532 std::complex<double> beta,
4533 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
4534 int batch_count, ScratchAllocator *scratch_allocator) {
4535 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4536 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4537 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4538
4539 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4540 std::complex<double>,
4541 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4542 int,
4543 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4544 int, std::complex<double>,
4545 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4546 int, int, ScratchAllocator *>
4547 impl;
4548 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4549 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4550 scratch_allocator);
4551 }
4552
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,int64 stride_a,const DeviceMemory<Eigen::half> & b,int ldb,int64 stride_b,float beta,DeviceMemory<Eigen::half> * c,int ldc,int64 stride_c,int batch_count)4553 Stream &Stream::ThenBlasGemmStridedBatched(
4554 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4555 uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
4556 int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b,
4557 float beta, DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c,
4558 int batch_count) {
4559 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4560 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4561 PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4562 PARAM(stride_c), PARAM(batch_count));
4563
4564 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4565 const DeviceMemory<Eigen::half> &, int, int64,
4566 const DeviceMemory<Eigen::half> &, int, int64, float,
4567 DeviceMemory<Eigen::half> *, int, int64, int>
4568 impl;
4569 return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4570 transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4571 c, ldc, stride_c, batch_count);
4572 }
4573
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,int64 stride_a,const DeviceMemory<float> & b,int ldb,int64 stride_b,float beta,DeviceMemory<float> * c,int ldc,int64 stride_c,int batch_count)4574 Stream &Stream::ThenBlasGemmStridedBatched(
4575 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4576 uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
4577 int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
4578 float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
4579 int batch_count) {
4580 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4581 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4582 PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4583 PARAM(stride_c), PARAM(batch_count));
4584
4585 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4586 const DeviceMemory<float> &, int, int64,
4587 const DeviceMemory<float> &, int, int64, float,
4588 DeviceMemory<float> *, int, int64, int>
4589 impl;
4590 return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4591 transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4592 c, ldc, stride_c, batch_count);
4593 }
4594
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,int64 stride_a,const DeviceMemory<double> & b,int ldb,int64 stride_b,double beta,DeviceMemory<double> * c,int ldc,int64 stride_c,int batch_count)4595 Stream &Stream::ThenBlasGemmStridedBatched(
4596 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4597 uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
4598 int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
4599 double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
4600 int batch_count) {
4601 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4602 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4603 PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4604 PARAM(stride_c), PARAM(batch_count));
4605
4606 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
4607 const DeviceMemory<double> &, int, int64,
4608 const DeviceMemory<double> &, int, int64, double,
4609 DeviceMemory<double> *, int, int64, int>
4610 impl;
4611 return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4612 transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4613 c, ldc, stride_c, batch_count);
4614 }
4615
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,int64 stride_a,const DeviceMemory<std::complex<float>> & b,int ldb,int64 stride_b,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,int64 stride_c,int batch_count)4616 Stream &Stream::ThenBlasGemmStridedBatched(
4617 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4618 uint64 k, std::complex<float> alpha,
4619 const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
4620 const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
4621 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
4622 int64 stride_c, int batch_count) {
4623 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4624 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4625 PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4626 PARAM(stride_c), PARAM(batch_count));
4627
4628 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4629 std::complex<float>, const DeviceMemory<std::complex<float>> &,
4630 int, int64, const DeviceMemory<std::complex<float>> &, int,
4631 int64, std::complex<float>, DeviceMemory<std::complex<float>> *,
4632 int, int64, int>
4633 impl;
4634 return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4635 transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4636 c, ldc, stride_c, batch_count);
4637 }
4638
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,int64 stride_a,const DeviceMemory<std::complex<double>> & b,int ldb,int64 stride_b,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,int64 stride_c,int batch_count)4639 Stream &Stream::ThenBlasGemmStridedBatched(
4640 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4641 uint64 k, std::complex<double> alpha,
4642 const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
4643 const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
4644 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
4645 int64 stride_c, int batch_count) {
4646 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4647 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4648 PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4649 PARAM(stride_c), PARAM(batch_count));
4650
4651 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4652 std::complex<double>, const DeviceMemory<std::complex<double>> &,
4653 int, int64, const DeviceMemory<std::complex<double>> &, int,
4654 int64, std::complex<double>,
4655 DeviceMemory<std::complex<double>> *, int, int64, int>
4656 impl;
4657 return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4658 transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4659 c, ldc, stride_c, batch_count);
4660 }
4661
ThenSetRngSeed(const uint8 * seed,uint64 seed_bytes)4662 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
4663 VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
4664
4665 if (ok()) {
4666 if (rng::RngSupport *rng = parent_->AsRng()) {
4667 CheckError(rng->SetSeed(this, seed, seed_bytes));
4668 } else {
4669 SetError();
4670 LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG";
4671 }
4672 } else {
4673 LOG(INFO) << DebugStreamPointers()
4674 << " did not set RNG seed: " << static_cast<const void *>(seed)
4675 << "; bytes: " << seed_bytes;
4676 }
4677 return *this;
4678 }
4679
ThenPopulateRandUniform(DeviceMemory<float> * values)4680 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
4681 VLOG_CALL(PARAM(values));
4682
4683 if (ok()) {
4684 if (rng::RngSupport *rng = parent_->AsRng()) {
4685 CheckError(rng->DoPopulateRandUniform(this, values));
4686 } else {
4687 SetError();
4688 LOG(INFO) << DebugStreamPointers()
4689 << " attempting to perform RNG operation using StreamExecutor"
4690 " without RNG support.";
4691 }
4692 }
4693 return *this;
4694 }
4695
ThenPopulateRandGaussian(float mean,float sd,DeviceMemory<float> * values)4696 Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
4697 DeviceMemory<float> *values) {
4698 VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
4699
4700 if (ok()) {
4701 if (rng::RngSupport *rng = parent_->AsRng()) {
4702 CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
4703 } else {
4704 SetError();
4705 LOG(INFO) << DebugStreamPointers()
4706 << " attempting to perform RNG operation using StreamExecutor"
4707 " without RNG support.";
4708 }
4709 }
4710 return *this;
4711 }
4712
ThenPopulateRandGaussian(double mean,double sd,DeviceMemory<double> * values)4713 Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
4714 DeviceMemory<double> *values) {
4715 VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
4716
4717 if (ok()) {
4718 if (rng::RngSupport *rng = parent_->AsRng()) {
4719 CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
4720 } else {
4721 SetError();
4722 LOG(INFO) << DebugStreamPointers()
4723 << " attempting to perform RNG operation using StreamExecutor"
4724 " without RNG support.";
4725 }
4726 }
4727 return *this;
4728 }
4729
ThenPopulateRandUniform(DeviceMemory<double> * values)4730 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
4731 VLOG_CALL(PARAM(values));
4732
4733 if (ok()) {
4734 if (rng::RngSupport *rng = parent_->AsRng()) {
4735 CheckError(rng->DoPopulateRandUniform(this, values));
4736 } else {
4737 SetError();
4738 LOG(INFO) << DebugStreamPointers()
4739 << " attempting to perform RNG operation using StreamExecutor"
4740 " without RNG support.";
4741 }
4742 }
4743 return *this;
4744 }
4745
ThenPopulateRandUniform(DeviceMemory<std::complex<float>> * values)4746 Stream &Stream::ThenPopulateRandUniform(
4747 DeviceMemory<std::complex<float>> *values) {
4748 VLOG_CALL(PARAM(values));
4749
4750 if (ok()) {
4751 if (rng::RngSupport *rng = parent_->AsRng()) {
4752 CheckError(rng->DoPopulateRandUniform(this, values));
4753 } else {
4754 SetError();
4755 LOG(INFO) << DebugStreamPointers()
4756 << " attempting to perform RNG operation using StreamExecutor"
4757 " without RNG support.";
4758 }
4759 }
4760 return *this;
4761 }
4762
ThenPopulateRandUniform(DeviceMemory<std::complex<double>> * values)4763 Stream &Stream::ThenPopulateRandUniform(
4764 DeviceMemory<std::complex<double>> *values) {
4765 VLOG_CALL(PARAM(values));
4766
4767 if (ok()) {
4768 if (rng::RngSupport *rng = parent_->AsRng()) {
4769 CheckError(rng->DoPopulateRandUniform(this, values));
4770 } else {
4771 SetError();
4772 LOG(INFO) << DebugStreamPointers()
4773 << " attempting to perform RNG operation using StreamExecutor"
4774 " without RNG support.";
4775 }
4776 }
4777 return *this;
4778 }
4779
ThenMemcpy(void * host_dst,const DeviceMemoryBase & gpu_src,uint64 size)4780 Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
4781 uint64 size) {
4782 VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size));
4783
4784 if (ok()) {
4785 CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
4786 } else {
4787 LOG(INFO) << DebugStreamPointers()
4788 << " did not memcpy device-to-host; source: " << gpu_src.opaque();
4789 }
4790 return *this;
4791 }
4792
ThenMemcpy(DeviceMemoryBase * gpu_dst,const void * host_src,uint64 size)4793 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
4794 uint64 size) {
4795 VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size));
4796
4797 if (ok()) {
4798 CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
4799 } else {
4800 LOG(INFO) << DebugStreamPointers()
4801 << " did not memcpy host-to-device; source: " << host_src;
4802 }
4803 return *this;
4804 }
4805
ThenMemcpy(DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)4806 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
4807 const DeviceMemoryBase &gpu_src, uint64 size) {
4808 VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size));
4809
4810 if (ok()) {
4811 CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
4812 } else {
4813 LOG(INFO) << DebugStreamPointers()
4814 << " did not memcpy gpu-to-gpu; source: " << &gpu_src;
4815 }
4816 return *this;
4817 }
4818
ThenMemZero(DeviceMemoryBase * location,uint64 size)4819 Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
4820 VLOG_CALL(PARAM(location), PARAM(size));
4821
4822 if (ok()) {
4823 CheckError(parent_->MemZero(this, location, size));
4824 } else {
4825 LOG(INFO) << DebugStreamPointers()
4826 << " did not memzero GPU location; source: " << location;
4827 }
4828 return *this;
4829 }
4830
ThenMemset32(DeviceMemoryBase * location,uint32 pattern,uint64 size)4831 Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
4832 uint64 size) {
4833 VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size));
4834
4835 if (ok()) {
4836 CheckError(parent_->Memset32(this, location, pattern, size));
4837 } else {
4838 LOG(INFO) << DebugStreamPointers()
4839 << " did not memset GPU location; source: " << location
4840 << "; size: " << size << "; pattern: " << std::hex << pattern;
4841 }
4842 return *this;
4843 }
4844
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<Eigen::half> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<Eigen::half> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<Eigen::half> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4845 Stream &Stream::ThenRnnForward(
4846 const dnn::RnnDescriptor &rnn_desc,
4847 const dnn::RnnSequenceTensorDescriptor &input_desc,
4848 const DeviceMemory<Eigen::half> &input_data,
4849 const dnn::RnnStateTensorDescriptor &input_h_desc,
4850 const DeviceMemory<Eigen::half> &input_h_data,
4851 const dnn::RnnStateTensorDescriptor &input_c_desc,
4852 const DeviceMemory<Eigen::half> &input_c_data,
4853 const DeviceMemory<Eigen::half> ¶ms,
4854 const dnn::RnnSequenceTensorDescriptor &output_desc,
4855 DeviceMemory<Eigen::half> *output_data,
4856 const dnn::RnnStateTensorDescriptor &output_h_desc,
4857 DeviceMemory<Eigen::half> *output_h_data,
4858 const dnn::RnnStateTensorDescriptor &output_c_desc,
4859 DeviceMemory<Eigen::half> *output_c_data, bool is_training,
4860 ScratchAllocator *reserve_space_allocator,
4861 ScratchAllocator *workspace_allocator,
4862 dnn::ProfileResult *output_profile_result) {
4863 // TODO(zhengxq): add VLOG PARAM calls.
4864 if (ok()) {
4865 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4866 auto status = dnn->DoRnnForward(
4867 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4868 input_c_desc, input_c_data, params, output_desc, output_data,
4869 output_h_desc, output_h_data, output_c_desc, output_c_data,
4870 is_training, reserve_space_allocator, workspace_allocator,
4871 output_profile_result);
4872 if (!status && !output_profile_result) {
4873 SetError();
4874 }
4875 } else {
4876 SetErrorAndLogNoDnnSupport();
4877 }
4878 }
4879 return *this;
4880 }
4881
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<float> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<float> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<float> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4882 Stream &Stream::ThenRnnForward(
4883 const dnn::RnnDescriptor &rnn_desc,
4884 const dnn::RnnSequenceTensorDescriptor &input_desc,
4885 const DeviceMemory<float> &input_data,
4886 const dnn::RnnStateTensorDescriptor &input_h_desc,
4887 const DeviceMemory<float> &input_h_data,
4888 const dnn::RnnStateTensorDescriptor &input_c_desc,
4889 const DeviceMemory<float> &input_c_data, const DeviceMemory<float> ¶ms,
4890 const dnn::RnnSequenceTensorDescriptor &output_desc,
4891 DeviceMemory<float> *output_data,
4892 const dnn::RnnStateTensorDescriptor &output_h_desc,
4893 DeviceMemory<float> *output_h_data,
4894 const dnn::RnnStateTensorDescriptor &output_c_desc,
4895 DeviceMemory<float> *output_c_data, bool is_training,
4896 ScratchAllocator *reserve_space_allocator,
4897 ScratchAllocator *workspace_allocator,
4898 dnn::ProfileResult *output_profile_result) {
4899 // TODO(zhengxq): add VLOG PARAM calls.
4900 if (ok()) {
4901 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4902 auto status = dnn->DoRnnForward(
4903 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4904 input_c_desc, input_c_data, params, output_desc, output_data,
4905 output_h_desc, output_h_data, output_c_desc, output_c_data,
4906 is_training, reserve_space_allocator, workspace_allocator,
4907 output_profile_result);
4908 if (!status && !output_profile_result) {
4909 SetError();
4910 }
4911 } else {
4912 SetErrorAndLogNoDnnSupport();
4913 }
4914 }
4915 return *this;
4916 }
4917
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<double> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<double> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<double> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4918 Stream &Stream::ThenRnnForward(
4919 const dnn::RnnDescriptor &rnn_desc,
4920 const dnn::RnnSequenceTensorDescriptor &input_desc,
4921 const DeviceMemory<double> &input_data,
4922 const dnn::RnnStateTensorDescriptor &input_h_desc,
4923 const DeviceMemory<double> &input_h_data,
4924 const dnn::RnnStateTensorDescriptor &input_c_desc,
4925 const DeviceMemory<double> &input_c_data,
4926 const DeviceMemory<double> ¶ms,
4927 const dnn::RnnSequenceTensorDescriptor &output_desc,
4928 DeviceMemory<double> *output_data,
4929 const dnn::RnnStateTensorDescriptor &output_h_desc,
4930 DeviceMemory<double> *output_h_data,
4931 const dnn::RnnStateTensorDescriptor &output_c_desc,
4932 DeviceMemory<double> *output_c_data, bool is_training,
4933 ScratchAllocator *reserve_space_allocator,
4934 ScratchAllocator *workspace_allocator,
4935 dnn::ProfileResult *output_profile_result) {
4936 // TODO(zhengxq): add VLOG PARAM calls.
4937 if (ok()) {
4938 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4939 auto status = dnn->DoRnnForward(
4940 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4941 input_c_desc, input_c_data, params, output_desc, output_data,
4942 output_h_desc, output_h_data, output_c_desc, output_c_data,
4943 is_training, reserve_space_allocator, workspace_allocator,
4944 output_profile_result);
4945 if (!status && !output_profile_result) {
4946 SetError();
4947 }
4948 } else {
4949 SetErrorAndLogNoDnnSupport();
4950 }
4951 }
4952 return *this;
4953 }
4954
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<Eigen::half> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<Eigen::half> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<Eigen::half> & output_c_data,const DeviceMemory<Eigen::half> & output_backprop_data,const DeviceMemory<Eigen::half> & output_h_backprop_data,const DeviceMemory<Eigen::half> & output_c_backprop_data,DeviceMemory<Eigen::half> * input_backprop_data,DeviceMemory<Eigen::half> * input_h_backprop_data,DeviceMemory<Eigen::half> * input_c_backprop_data,DeviceMemory<Eigen::half> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4955 Stream &Stream::ThenRnnBackward(
4956 const dnn::RnnDescriptor &rnn_desc,
4957 const dnn::RnnSequenceTensorDescriptor &input_desc,
4958 const DeviceMemory<Eigen::half> &input_data,
4959 const dnn::RnnStateTensorDescriptor &input_h_desc,
4960 const DeviceMemory<Eigen::half> &input_h_data,
4961 const dnn::RnnStateTensorDescriptor &input_c_desc,
4962 const DeviceMemory<Eigen::half> &input_c_data,
4963 const DeviceMemory<Eigen::half> ¶ms,
4964 const dnn::RnnSequenceTensorDescriptor &output_desc,
4965 const DeviceMemory<Eigen::half> &output_data,
4966 const dnn::RnnStateTensorDescriptor &output_h_desc,
4967 const DeviceMemory<Eigen::half> &output_h_data,
4968 const dnn::RnnStateTensorDescriptor &output_c_desc,
4969 const DeviceMemory<Eigen::half> &output_c_data,
4970 const DeviceMemory<Eigen::half> &output_backprop_data,
4971 const DeviceMemory<Eigen::half> &output_h_backprop_data,
4972 const DeviceMemory<Eigen::half> &output_c_backprop_data,
4973 DeviceMemory<Eigen::half> *input_backprop_data,
4974 DeviceMemory<Eigen::half> *input_h_backprop_data,
4975 DeviceMemory<Eigen::half> *input_c_backprop_data,
4976 DeviceMemory<Eigen::half> *params_backprop_data,
4977 DeviceMemory<uint8> *reserve_space_data,
4978 ScratchAllocator *workspace_allocator,
4979 dnn::ProfileResult *output_profile_result) {
4980 // TODO(zhengxq): add VLOG PARAM calls.
4981 if (ok()) {
4982 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4983 auto status = dnn->DoRnnBackward(
4984 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4985 input_c_desc, input_c_data, params, output_desc, output_data,
4986 output_h_desc, output_h_data, output_c_desc, output_c_data,
4987 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4988 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4989 params_backprop_data, reserve_space_data, workspace_allocator,
4990 output_profile_result);
4991 if (!status && !output_profile_result) {
4992 SetError();
4993 }
4994 } else {
4995 SetError();
4996 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4997 }
4998 }
4999 return *this;
5000 }
5001
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<float> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<float> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<float> & output_c_data,const DeviceMemory<float> & output_backprop_data,const DeviceMemory<float> & output_h_backprop_data,const DeviceMemory<float> & output_c_backprop_data,DeviceMemory<float> * input_backprop_data,DeviceMemory<float> * input_h_backprop_data,DeviceMemory<float> * input_c_backprop_data,DeviceMemory<float> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)5002 Stream &Stream::ThenRnnBackward(
5003 const dnn::RnnDescriptor &rnn_desc,
5004 const dnn::RnnSequenceTensorDescriptor &input_desc,
5005 const DeviceMemory<float> &input_data,
5006 const dnn::RnnStateTensorDescriptor &input_h_desc,
5007 const DeviceMemory<float> &input_h_data,
5008 const dnn::RnnStateTensorDescriptor &input_c_desc,
5009 const DeviceMemory<float> &input_c_data, const DeviceMemory<float> ¶ms,
5010 const dnn::RnnSequenceTensorDescriptor &output_desc,
5011 const DeviceMemory<float> &output_data,
5012 const dnn::RnnStateTensorDescriptor &output_h_desc,
5013 const DeviceMemory<float> &output_h_data,
5014 const dnn::RnnStateTensorDescriptor &output_c_desc,
5015 const DeviceMemory<float> &output_c_data,
5016 const DeviceMemory<float> &output_backprop_data,
5017 const DeviceMemory<float> &output_h_backprop_data,
5018 const DeviceMemory<float> &output_c_backprop_data,
5019 DeviceMemory<float> *input_backprop_data,
5020 DeviceMemory<float> *input_h_backprop_data,
5021 DeviceMemory<float> *input_c_backprop_data,
5022 DeviceMemory<float> *params_backprop_data,
5023 DeviceMemory<uint8> *reserve_space_data,
5024 ScratchAllocator *workspace_allocator,
5025 dnn::ProfileResult *output_profile_result) {
5026 // TODO(zhengxq): add VLOG PARAM calls.
5027 if (ok()) {
5028 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5029 auto status = dnn->DoRnnBackward(
5030 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
5031 input_c_desc, input_c_data, params, output_desc, output_data,
5032 output_h_desc, output_h_data, output_c_desc, output_c_data,
5033 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
5034 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
5035 params_backprop_data, reserve_space_data, workspace_allocator,
5036 output_profile_result);
5037 if (!status && !output_profile_result) {
5038 SetError();
5039 }
5040 } else {
5041 SetError();
5042 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
5043 }
5044 }
5045 return *this;
5046 }
5047
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<double> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<double> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<double> & output_c_data,const DeviceMemory<double> & output_backprop_data,const DeviceMemory<double> & output_h_backprop_data,const DeviceMemory<double> & output_c_backprop_data,DeviceMemory<double> * input_backprop_data,DeviceMemory<double> * input_h_backprop_data,DeviceMemory<double> * input_c_backprop_data,DeviceMemory<double> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)5048 Stream &Stream::ThenRnnBackward(
5049 const dnn::RnnDescriptor &rnn_desc,
5050 const dnn::RnnSequenceTensorDescriptor &input_desc,
5051 const DeviceMemory<double> &input_data,
5052 const dnn::RnnStateTensorDescriptor &input_h_desc,
5053 const DeviceMemory<double> &input_h_data,
5054 const dnn::RnnStateTensorDescriptor &input_c_desc,
5055 const DeviceMemory<double> &input_c_data,
5056 const DeviceMemory<double> ¶ms,
5057 const dnn::RnnSequenceTensorDescriptor &output_desc,
5058 const DeviceMemory<double> &output_data,
5059 const dnn::RnnStateTensorDescriptor &output_h_desc,
5060 const DeviceMemory<double> &output_h_data,
5061 const dnn::RnnStateTensorDescriptor &output_c_desc,
5062 const DeviceMemory<double> &output_c_data,
5063 const DeviceMemory<double> &output_backprop_data,
5064 const DeviceMemory<double> &output_h_backprop_data,
5065 const DeviceMemory<double> &output_c_backprop_data,
5066 DeviceMemory<double> *input_backprop_data,
5067 DeviceMemory<double> *input_h_backprop_data,
5068 DeviceMemory<double> *input_c_backprop_data,
5069 DeviceMemory<double> *params_backprop_data,
5070 DeviceMemory<uint8> *reserve_space_data,
5071 ScratchAllocator *workspace_allocator,
5072 dnn::ProfileResult *output_profile_result) {
5073 // TODO(zhengxq): add VLOG PARAM calls.
5074 if (ok()) {
5075 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5076 auto status = dnn->DoRnnBackward(
5077 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
5078 input_c_desc, input_c_data, params, output_desc, output_data,
5079 output_h_desc, output_h_data, output_c_desc, output_c_data,
5080 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
5081 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
5082 params_backprop_data, reserve_space_data, workspace_allocator,
5083 output_profile_result);
5084 if (!status && !output_profile_result) {
5085 SetError();
5086 }
5087 } else {
5088 SetError();
5089 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
5090 }
5091 }
5092 return *this;
5093 }
5094
ThenTransformTensor(const dnn::BatchDescriptor & input_desc,dnn::DataType input_type,const DeviceMemoryBase & input_data,const dnn::BatchDescriptor & output_desc,dnn::DataType output_type,float scale,DeviceMemoryBase * output_data)5095 Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
5096 dnn::DataType input_type,
5097 const DeviceMemoryBase &input_data,
5098 const dnn::BatchDescriptor &output_desc,
5099 dnn::DataType output_type, float scale,
5100 DeviceMemoryBase *output_data) {
5101 VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data),
5102 PARAM(output_desc), PARAM(output_type), PARAM(scale),
5103 PARAM(output_data));
5104 if (ok()) {
5105 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5106 CheckError(dnn->DoTransformTensor(this, input_desc, input_type,
5107 input_data, output_desc, output_type,
5108 scale, output_data));
5109 } else {
5110 SetErrorAndLogNoDnnSupport();
5111 }
5112 }
5113 return *this;
5114 }
5115
ThenDoHostCallback(std::function<void ()> callback)5116 Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
5117 VLOG_CALL(PARAM(callback));
5118
5119 if (!ok()) {
5120 LOG(INFO) << DebugStreamPointers()
5121 << " was in error state before adding host callback";
5122 }
5123 CheckError(parent_->HostCallback(this, std::move(callback)));
5124 return *this;
5125 }
5126
ThenDoHostCallbackWithStatus(std::function<port::Status ()> callback)5127 Stream &Stream::ThenDoHostCallbackWithStatus(
5128 std::function<port::Status()> callback) {
5129 VLOG_CALL(PARAM(callback));
5130
5131 if (!ok()) {
5132 LOG(INFO) << DebugStreamPointers()
5133 << " was in error state before adding host callback";
5134 }
5135 CheckError(parent_->HostCallback(this, std::move(callback)));
5136 return *this;
5137 }
5138
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<std::complex<float>> * output)5139 Stream &Stream::ThenFft(fft::Plan *plan,
5140 const DeviceMemory<std::complex<float>> &input,
5141 DeviceMemory<std::complex<float>> *output) {
5142 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5143
5144 if (ok()) {
5145 if (fft::FftSupport *fft = parent_->AsFft()) {
5146 CheckError(fft->DoFft(this, plan, input, output));
5147 } else {
5148 SetError();
5149 LOG(INFO) << DebugStreamPointers()
5150 << " attempting to perform FFT operation using StreamExecutor"
5151 " without FFT support";
5152 }
5153 }
5154 return *this;
5155 }
5156
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<std::complex<double>> * output)5157 Stream &Stream::ThenFft(fft::Plan *plan,
5158 const DeviceMemory<std::complex<double>> &input,
5159 DeviceMemory<std::complex<double>> *output) {
5160 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5161
5162 if (ok()) {
5163 if (fft::FftSupport *fft = parent_->AsFft()) {
5164 CheckError(fft->DoFft(this, plan, input, output));
5165 } else {
5166 SetError();
5167 LOG(INFO) << DebugStreamPointers()
5168 << " attempting to perform FFT operation using StreamExecutor"
5169 " without FFT support";
5170 }
5171 }
5172 return *this;
5173 }
5174
ThenFft(fft::Plan * plan,const DeviceMemory<float> & input,DeviceMemory<std::complex<float>> * output)5175 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
5176 DeviceMemory<std::complex<float>> *output) {
5177 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5178
5179 if (ok()) {
5180 if (fft::FftSupport *fft = parent_->AsFft()) {
5181 CheckError(fft->DoFft(this, plan, input, output));
5182 } else {
5183 SetError();
5184 LOG(INFO) << DebugStreamPointers()
5185 << " attempting to perform FFT operation using StreamExecutor"
5186 " without FFT support";
5187 }
5188 }
5189 return *this;
5190 }
5191
ThenFft(fft::Plan * plan,const DeviceMemory<double> & input,DeviceMemory<std::complex<double>> * output)5192 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
5193 DeviceMemory<std::complex<double>> *output) {
5194 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5195
5196 if (ok()) {
5197 if (fft::FftSupport *fft = parent_->AsFft()) {
5198 CheckError(fft->DoFft(this, plan, input, output));
5199 } else {
5200 SetError();
5201 LOG(INFO) << DebugStreamPointers()
5202 << " attempting to perform FFT operation using StreamExecutor"
5203 " without FFT support";
5204 }
5205 }
5206 return *this;
5207 }
5208
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<float> * output)5209 Stream &Stream::ThenFft(fft::Plan *plan,
5210 const DeviceMemory<std::complex<float>> &input,
5211 DeviceMemory<float> *output) {
5212 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5213
5214 if (ok()) {
5215 if (fft::FftSupport *fft = parent_->AsFft()) {
5216 CheckError(fft->DoFft(this, plan, input, output));
5217 } else {
5218 SetError();
5219 LOG(INFO) << DebugStreamPointers()
5220 << " attempting to perform FFT operation using StreamExecutor"
5221 " without FFT support";
5222 }
5223 }
5224 return *this;
5225 }
5226
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<double> * output)5227 Stream &Stream::ThenFft(fft::Plan *plan,
5228 const DeviceMemory<std::complex<double>> &input,
5229 DeviceMemory<double> *output) {
5230 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5231
5232 if (ok()) {
5233 if (fft::FftSupport *fft = parent_->AsFft()) {
5234 CheckError(fft->DoFft(this, plan, input, output));
5235 } else {
5236 SetError();
5237 LOG(INFO) << DebugStreamPointers()
5238 << " attempting to perform FFT operation using StreamExecutor"
5239 " without FFT support";
5240 }
5241 }
5242 return *this;
5243 }
5244
5245 // It looks confusing, but all this is doing is inserting a callback at the
5246 // present point in the stream to then enqueue a task on the host executor.
ThenEnqueueOnBackgroundThread(std::function<void (StreamExecutor *)> task)5247 Stream &Stream::ThenEnqueueOnBackgroundThread(
5248 std::function<void(StreamExecutor *)> task) {
5249 VLOG_CALL(PARAM(task));
5250
5251 StreamExecutor *stream_executor = this->parent_;
5252 std::function<void()> bound_task = std::bind(task, stream_executor);
5253
5254 return ThenDoHostCallback([stream_executor, bound_task]() {
5255 stream_executor->EnqueueOnBackgroundThread(bound_task);
5256 });
5257 }
5258
BlockHostUntilDone()5259 port::Status Stream::BlockHostUntilDone() {
5260 VLOG_CALL();
5261
5262 if (!ok()) {
5263 port::Status status = port::Status(
5264 port::error::INTERNAL,
5265 "stream did not block host until done; was already in an error state");
5266 LOG(INFO) << DebugStreamPointers() << " " << status;
5267 return status;
5268 }
5269
5270 temporary_memory_manager_.DeallocateFinalizedTemporaries();
5271
5272 port::Status error = parent_->BlockHostUntilDone(this);
5273 CheckError(error.ok());
5274 return error;
5275 }
5276
DebugStreamPointers() const5277 string Stream::DebugStreamPointers() const {
5278 // Relies on the ToVlogString(const void*) overload above.
5279 return absl::StrCat("[stream=", ToVlogString(this),
5280 ",impl=", ToVlogString(implementation_.get()), "]");
5281 }
5282
CheckStatus(port::Status status)5283 void Stream::CheckStatus(port::Status status) {
5284 if (status.ok()) {
5285 return;
5286 }
5287 LOG(ERROR) << status;
5288 mutex_lock lock(mu_);
5289 ok_ = false;
5290 }
5291
5292 } // namespace stream_executor
5293