1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // The Stream is used in conjunction with the StreamExecutor "parent" to
17 // perform actions with a linear stream of dependencies. Dependencies can also
18 // be created between Streams to do task management (i.e. limit which tasks
19 // can be performed concurrently and specify what task dependencies exist).
20 
21 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
22 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
23 
24 #include <complex>
25 #include <functional>
26 #include <memory>
27 
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow/stream_executor/blas.h"
30 #include "tensorflow/stream_executor/device_memory.h"
31 #include "tensorflow/stream_executor/dnn.h"
32 #include "tensorflow/stream_executor/event.h"
33 #include "tensorflow/stream_executor/fft.h"
34 #include "tensorflow/stream_executor/host_or_device_scalar.h"
35 #include "tensorflow/stream_executor/kernel.h"
36 #include "tensorflow/stream_executor/launch_dim.h"
37 #include "tensorflow/stream_executor/lib/array_slice.h"
38 #include "tensorflow/stream_executor/platform/mutex.h"
39 #include "tensorflow/stream_executor/platform/port.h"
40 #include "tensorflow/stream_executor/platform/thread_annotations.h"
41 #include "tensorflow/stream_executor/temporary_memory_manager.h"
42 
43 namespace stream_executor {
44 
45 namespace host {
46 class HostBlas;
47 class HostFft;
48 class HostRng;
49 class HostTimer;
50 }  // namespace host
51 
52 namespace ocl {
53 class CLBlas;
54 }  // namespace ocl
55 
56 namespace internal {
57 class StreamInterface;
58 }  // namespace internal
59 
60 class DeviceMemoryBase;
61 template <typename ElemT>
62 class DeviceMemory;
63 
64 class Timer;
65 
66 namespace dnn {
67 class BatchDescriptor;
68 class FilterDescriptor;
69 class ConvolutionDescriptor;
70 class ProfileResult;
71 class AlgorithmDesc;
72 }  // namespace dnn
73 
74 class StreamExecutor;
75 class ScratchAllocator;
76 
77 // Convert a type to the corresponding QuantizedActivationMode.
78 template <typename ElementType>
79 struct Quantization;
80 
81 // Represents a stream of dependent computations on a GPU device.
82 //
83 // The operations within a stream execute linearly and asynchronously until
84 // BlockHostUntilDone() is invoked, which synchronously joins host code with
85 // the execution of the stream.
86 //
87 // If any given operation fails when entraining work for the stream, ok() will
88 // indicate that an error has occurred. After initialization, once a stream is
89 // !ok(), it will never be ok().
90 //
91 // Thread-safe post-initialization.
92 class Stream {
93  public:
94   // Instantiate a stream tied to parent as a platform executor. Work
95   // entrained onto this stream will be launched/managed on that
96   // StreamExecutor's platform.
97   explicit Stream(StreamExecutor *parent);
98 
99   // Test only. Use an externally-populated value (like a mock) for the
100   // platform-specific stream implementation.
101   Stream(StreamExecutor *parent, internal::StreamInterface *implementation);
102 
103   // Deallocates any stream resources that the parent StreamExecutor has
104   // bestowed
105   // upon this object.
106   ~Stream();
107 
108   // Returns whether any errors have occurred while entraining work for this
109   // stream.
ok()110   bool ok() const { return !InErrorState(); }
111 
112   // Retrieves execution status back into the stream from the underlying
113   // implementation without blocking the stream.
114   //
115   // Normally, Stream::BlockHostUntilDone is used to get execution status.
116   // However, some devices use out-of-band mechanisms to ensure their streams
117   // have finished on-device work, without needing to block the streams. (These
118   // devices should also override AllowsSyncOnCompletion to return false.) For
119   // these devices, this method can be used after work is finished to retrieve
120   // execution status.
121   port::Status RefreshStatus() LOCKS_EXCLUDED(mu_);
122 
123   // Initialize the stream. This must be performed before entraining any other
124   // operations.
125   Stream &Init() LOCKS_EXCLUDED(mu_);
126 
127   // Initializes timer t via the StreamExecutor.
128   Stream &InitTimer(Timer *t);
129 
130   // Convenience wrapper around Init() and InitTimer().
131   Stream &InitWithTimer(Timer *t);
132 
133   // Get or create a sub-stream from this stream. If there is any sub-stream in
134   // the pool that can be reused then just return this sub-stream.  Otherwise
135   // create a new sub-stream.
136   //
137   // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
138   Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_);
139 
140   // Return the sub-stream back to the host stream so that it can be reused
141   // later. Sub-streams that are !ok() will not be reused.
142   //
143   // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
144   void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_);
145 
146   // Allocate temporary memories. The stream will deallocate them when blocked
147   // or destroyed.
148   template <typename T>
149   port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
150   AllocateTemporaryArray(uint64 element_count);
151 
152   // Entrains onto the stream of operations: a kernel launch with the given
153   // (variadic) parameters for the invocation. These arguments can be things
154   // like DeviceMemory or primitive types such as int. What arguments you may
155   // pass to a given kernel are noted as the template parameters to the
156   // TypedKernel type that the machocc compiler generates.
157   //
158   // Template parameters:
159   //  Params...   The type list of formal parameters that the typed kernel
160   //              expects, which is matched against Args...
161   //  Args...     The deduced type list for passed actual arguments
162   //
163   // Implementation: A compile-time compatibility check is performed that has
164   // some leniency versus an exact parameter pack match -- for example,
165   // `const DeviceMemory<T>` is considered "pack compatible" with a
166   // `const DeviceMemory<T>&` formal parameter; in part, because we don't have
167   // perfect forwarding support without rvalue references. It also attempts to
168   // spit out helpful static_assert error traces with information as to the
169   // argument number and types that were mismatched.
170   template <typename... Params, typename... Args>
171   Stream &ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
172                      const TypedKernel<Params...> &kernel, Args... args);
173 
174   // Record a "start" event for the interval timer at this point in the
175   // stream's execution (relative to the previously and subsequently enqueued
176   // items in the stream's execution). Streams may be started/stopped multiple
177   // times.
178   Stream &ThenStartTimer(Timer *t);
179 
180   // Record a "stop" event for the interval timer at this point in the
181   // stream's execution. See also Stream::ThenStartTimer.
182   Stream &ThenStopTimer(Timer *t);
183 
184   // TODO(leary) If work is added to the stream that is being depended upon,
185   //              then what? Have to describe what happens.
186   template <typename... Params>
ThenWaitFor(Stream * other,Params...more_streams)187   Stream &ThenWaitFor(Stream *other, Params... more_streams) {
188     return ThenWaitFor(more_streams...).ThenWaitFor(other);
189   }
190 
191   // Create a dependency for this stream's next work on the other stream
192   // completing. Does not take ownership of other, and other must not be
193   // null.
194   //
195   // Checks that a stream does not wait for itself, and it is up to the
196   // user to guarantee that a stream does not come to wait on itself in a
197   // cyclic manner; in that case, behavior is undefined.
198   //
199   // N.B. Base recursion case for the variadic ThenWaitFor.
200   Stream &ThenWaitFor(Stream *other);
201 
202   // Waits for all streams values in others.
203   // Checks that there is no shallow circular wait (i.e. that "this" is not in
204   // others)
205   template <typename P>
ThenWaitFor(P others)206   Stream &ThenWaitFor(P others) {
207     for (auto &stream : *others) {
208       CHECK_NE(stream.get(), this);
209       ThenWaitFor(stream.get());
210     }
211     return *this;
212   }
213 
214   // Waits for an event object to be set.
215   // Note that ThenRecordEvent must have been called on the event before
216   // you call this function; otherwise the event will be considered complete
217   // and this wait will do nothing.
218   Stream &ThenWaitFor(Event *event);
219 
220   // Inserts the specified event into the end of this stream. Once the stream
221   // has processed all events prior to the insertion point, the event will be
222   // marked as completed.
223   // The stream does not take ownership of event - meaning that event's lifetime
224   // must extend past the point at which it is marked complete!
225   Stream &ThenRecordEvent(Event *event);
226 
227   ////////////////
228   // DNN support
229   //
230   // See DnnSupport::* for comments on the following methods.
231 
232   Stream &ThenBatchNormalizationForward(
233       const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
234       const DeviceMemory<float> &offset,
235       const DeviceMemory<float> &estimated_mean,
236       const DeviceMemory<float> &estimated_variance,
237       const dnn::BatchDescriptor &x_desc,
238       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
239       DeviceMemory<float> *y, DeviceMemory<float> *batch_mean,
240       DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
241       DeviceMemory<float> *saved_inv_var, bool is_training,
242       std::function<const DeviceMemory<float> &()> var_to_inv_var,
243       std::function<void()> inv_var_to_var);
244 
245   Stream &ThenBatchNormalizationBackward(
246       const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
247       const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
248       const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
249       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
250       DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
251       DeviceMemory<float> *offset_backprop);
252 
253   Stream &ThenBatchNormalizationForward(
254       const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
255       const DeviceMemory<float> &offset,
256       const DeviceMemory<float> &estimated_mean,
257       const DeviceMemory<float> &estimated_variance,
258       const dnn::BatchDescriptor &x_desc,
259       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
260       DeviceMemory<Eigen::half> *y, DeviceMemory<float> *batch_mean,
261       DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
262       DeviceMemory<float> *saved_inv_var, bool is_training,
263       std::function<const DeviceMemory<float> &()> var_to_inv_var,
264       std::function<void()> inv_var_to_var);
265 
266   Stream &ThenBatchNormalizationBackward(
267       const DeviceMemory<Eigen::half> &y_backprop,
268       const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
269       const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
270       const dnn::BatchDescriptor &x_desc,
271       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
272       DeviceMemory<Eigen::half> *x_backprop,
273       DeviceMemory<float> *scale_backprop,
274       DeviceMemory<float> *offset_backprop);
275 
276   Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
277                        const DeviceMemory<float> &input_data,
278                        const dnn::FilterDescriptor &filter_descriptor,
279                        const DeviceMemory<float> &filter_data,
280                        const dnn::ConvolutionDescriptor &convolution_descriptor,
281                        const dnn::BatchDescriptor &output_descriptor,
282                        DeviceMemory<float> *output);
283 
284   Stream &ThenConvolveQuantized(
285       const dnn::BatchDescriptor &input_descriptor,
286       const DeviceMemory<float> &input_data,
287       const dnn::FilterDescriptor &filter_descriptor,
288       const DeviceMemory<int8> &filter_coefficients,
289       const DeviceMemory<float> &coefficient_scales,
290       const dnn::ConvolutionDescriptor &convolution_descriptor,
291       const dnn::BatchDescriptor &output_descriptor,
292       DeviceMemory<float> *output_data);
293 
294   Stream &ThenConvolveQuantized(
295       const dnn::BatchDescriptor &input_descriptor,
296       const DeviceMemory<float> &input_data,
297       const dnn::FilterDescriptor &filter_descriptor,
298       const DeviceMemory<int16> &filter_coefficients,
299       const DeviceMemory<float> &coefficient_scales,
300       const dnn::ConvolutionDescriptor &convolution_descriptor,
301       const dnn::BatchDescriptor &output_descriptor,
302       DeviceMemory<float> *output_data);
303 
304   Stream &ThenConvolveWithAlgorithm(
305       const dnn::BatchDescriptor &input_descriptor,
306       const DeviceMemory<double> &input_data,
307       const dnn::FilterDescriptor &filter_descriptor,
308       const DeviceMemory<double> &filter_data,
309       const dnn::ConvolutionDescriptor &convolution_descriptor,
310       const dnn::BatchDescriptor &output_descriptor,
311       DeviceMemory<double> *output, ScratchAllocator *scratch_allocator,
312       const dnn::AlgorithmConfig &algorithm_config,
313       dnn::ProfileResult *output_profile_result);
314 
315   Stream &ThenConvolveWithAlgorithm(
316       const dnn::BatchDescriptor &input_descriptor,
317       const DeviceMemory<float> &input_data,
318       const dnn::FilterDescriptor &filter_descriptor,
319       const DeviceMemory<float> &filter_data,
320       const dnn::ConvolutionDescriptor &convolution_descriptor,
321       const dnn::BatchDescriptor &output_descriptor,
322       DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
323       const dnn::AlgorithmConfig &algorithm_config,
324       dnn::ProfileResult *output_profile_result);
325 
326   Stream &ThenConvolveWithAlgorithm(
327       const dnn::BatchDescriptor &input_descriptor,
328       const DeviceMemory<Eigen::half> &input_data,
329       const dnn::FilterDescriptor &filter_descriptor,
330       const DeviceMemory<Eigen::half> &filter_data,
331       const dnn::ConvolutionDescriptor &convolution_descriptor,
332       const dnn::BatchDescriptor &output_descriptor,
333       DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
334       const dnn::AlgorithmConfig &algorithm_config,
335       dnn::ProfileResult *output_profile_result);
336 
337   Stream &ThenFusedConvolveWithAlgorithm(
338       const dnn::BatchDescriptor &conv_input_descriptor,
339       const DeviceMemory<double> &conv_input_data, double conv_input_scale,
340       const dnn::FilterDescriptor &filter_descriptor,
341       const DeviceMemory<double> &filter_data,
342       const dnn::ConvolutionDescriptor &convolution_descriptor,
343       const DeviceMemory<double> &side_input_data, double side_input_scale,
344       const dnn::BatchDescriptor &bias_descriptor,
345       const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
346       const dnn::BatchDescriptor &output_descriptor,
347       DeviceMemory<double> *output, ScratchAllocator *scratch_allocator,
348       const dnn::AlgorithmConfig &algorithm_config,
349       dnn::ProfileResult *output_profile_result);
350 
351   Stream &ThenFusedConvolveWithAlgorithm(
352       const dnn::BatchDescriptor &conv_input_descriptor,
353       const DeviceMemory<float> &conv_input_data, float conv_input_scale,
354       const dnn::FilterDescriptor &filter_descriptor,
355       const DeviceMemory<float> &filter_data,
356       const dnn::ConvolutionDescriptor &convolution_descriptor,
357       const DeviceMemory<float> &side_input_data, float side_input_scale,
358       const dnn::BatchDescriptor &bias_descriptor,
359       const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
360       const dnn::BatchDescriptor &output_descriptor,
361       DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
362       const dnn::AlgorithmConfig &algorithm_config,
363       dnn::ProfileResult *output_profile_result);
364 
365   Stream &ThenFusedConvolveWithAlgorithm(
366       const dnn::BatchDescriptor &conv_input_descriptor,
367       const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
368       const dnn::FilterDescriptor &filter_descriptor,
369       const DeviceMemory<Eigen::half> &filter_data,
370       const dnn::ConvolutionDescriptor &convolution_descriptor,
371       const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
372       const dnn::BatchDescriptor &bias_descriptor,
373       const DeviceMemory<Eigen::half> &biases,
374       dnn::ActivationMode activation_mode,
375       const dnn::BatchDescriptor &output_descriptor,
376       DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
377       const dnn::AlgorithmConfig &algorithm_config,
378       dnn::ProfileResult *output_profile_result);
379 
380   Stream &ThenFusedConvolveWithAlgorithm(
381       const dnn::BatchDescriptor &conv_input_descriptor,
382       const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
383       const dnn::FilterDescriptor &filter_descriptor,
384       const DeviceMemory<int8> &filter_data,
385       const dnn::ConvolutionDescriptor &convolution_descriptor,
386       const DeviceMemory<int8> &side_input_data, float side_input_scale,
387       const dnn::BatchDescriptor &bias_descriptor,
388       const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
389       const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
390       ScratchAllocator *scratch_allocator,
391       const dnn::AlgorithmConfig &algorithm_config,
392       dnn::ProfileResult *output_profile_result);
393 
394   Stream &ThenSeparableConvolve(
395       const dnn::BatchDescriptor &input_descriptor,
396       const DeviceMemory<float> &input_data,
397       const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
398       const DeviceMemory<float> &first_weights,
399       const DeviceMemory<float> &second_weights,
400       const dnn::ConvolutionDescriptor &convolution_descriptor,
401       const dnn::BatchDescriptor &output_descriptor,
402       DeviceMemory<float> *output);
403 
404   Stream &ThenConvolveBackwardDataWithAlgorithm(
405       const dnn::FilterDescriptor &filter_descriptor,
406       const DeviceMemory<double> &filter_data,
407       const dnn::BatchDescriptor &output_descriptor,
408       DeviceMemory<double> backward_output_data,
409       const dnn::ConvolutionDescriptor &convolution_descriptor,
410       const dnn::BatchDescriptor &input_descriptor,
411       DeviceMemory<double> *backward_input_data,
412       ScratchAllocator *scratch_allocator,
413       const dnn::AlgorithmConfig &algorithm_config,
414       dnn::ProfileResult *output_profile_result);
415 
416   Stream &ThenConvolveBackwardDataWithAlgorithm(
417       const dnn::FilterDescriptor &filter_descriptor,
418       const DeviceMemory<float> &filter_data,
419       const dnn::BatchDescriptor &output_descriptor,
420       DeviceMemory<float> backward_output_data,
421       const dnn::ConvolutionDescriptor &convolution_descriptor,
422       const dnn::BatchDescriptor &input_descriptor,
423       DeviceMemory<float> *backward_input_data,
424       ScratchAllocator *scratch_allocator,
425       const dnn::AlgorithmConfig &algorithm_config,
426       dnn::ProfileResult *output_profile_result);
427 
428   Stream &ThenConvolveBackwardDataWithAlgorithm(
429       const dnn::FilterDescriptor &filter_descriptor,
430       const DeviceMemory<Eigen::half> &filter_data,
431       const dnn::BatchDescriptor &output_descriptor,
432       DeviceMemory<Eigen::half> backward_output_data,
433       const dnn::ConvolutionDescriptor &convolution_descriptor,
434       const dnn::BatchDescriptor &input_descriptor,
435       DeviceMemory<Eigen::half> *backward_input_data,
436       ScratchAllocator *scratch_allocator,
437       const dnn::AlgorithmConfig &algorithm_config,
438       dnn::ProfileResult *output_profile_result);
439 
440   Stream &ThenConvolveBackwardFilterWithAlgorithm(
441       const dnn::BatchDescriptor &input_descriptor,
442       const DeviceMemory<double> &input_data,
443       const dnn::BatchDescriptor &output_descriptor,
444       DeviceMemory<double> backward_output_data,
445       const dnn::ConvolutionDescriptor &convolution_descriptor,
446       const dnn::FilterDescriptor &filter_descriptor,
447       DeviceMemory<double> *backward_filter_data,
448       ScratchAllocator *scratch_allocator,
449       const dnn::AlgorithmConfig &algorithm_config,
450       dnn::ProfileResult *output_profile_result);
451 
452   Stream &ThenConvolveBackwardFilterWithAlgorithm(
453       const dnn::BatchDescriptor &input_descriptor,
454       const DeviceMemory<float> &input_data,
455       const dnn::BatchDescriptor &output_descriptor,
456       DeviceMemory<float> backward_output_data,
457       const dnn::ConvolutionDescriptor &convolution_descriptor,
458       const dnn::FilterDescriptor &filter_descriptor,
459       DeviceMemory<float> *backward_filter_data,
460       ScratchAllocator *scratch_allocator,
461       const dnn::AlgorithmConfig &algorithm_config,
462       dnn::ProfileResult *output_profile_result);
463 
464   Stream &ThenConvolveBackwardFilterWithAlgorithm(
465       const dnn::BatchDescriptor &input_descriptor,
466       const DeviceMemory<Eigen::half> &input_data,
467       const dnn::BatchDescriptor &output_descriptor,
468       DeviceMemory<Eigen::half> backward_output_data,
469       const dnn::ConvolutionDescriptor &convolution_descriptor,
470       const dnn::FilterDescriptor &filter_descriptor,
471       DeviceMemory<Eigen::half> *backward_filter_data,
472       ScratchAllocator *scratch_allocator,
473       const dnn::AlgorithmConfig &algorithm_config,
474       dnn::ProfileResult *output_profile_result);
475 
476   Stream &ThenConvolveBackwardBias(const dnn::BatchDescriptor &input_descriptor,
477                                    const DeviceMemory<double> &input_data,
478                                    const dnn::BatchDescriptor &bias_descriptor,
479                                    DeviceMemory<double> *backward_bias_data);
480 
481   Stream &ThenConvolveBackwardBias(const dnn::BatchDescriptor &input_descriptor,
482                                    const DeviceMemory<float> &input_data,
483                                    const dnn::BatchDescriptor &bias_descriptor,
484                                    DeviceMemory<float> *backward_bias_data);
485 
486   Stream &ThenConvolveBackwardBias(
487       const dnn::BatchDescriptor &input_descriptor,
488       const DeviceMemory<Eigen::half> &input_data,
489       const dnn::BatchDescriptor &bias_descriptor,
490       DeviceMemory<Eigen::half> *backward_bias_data);
491 
492   Stream &ThenMatMul(const DeviceMemory<float> &input_data,
493                      const DeviceMemory<float> &weights,
494                      const dnn::BatchDescriptor &input_dimensions,
495                      const dnn::BatchDescriptor &output_dimensions,
496                      DeviceMemory<float> *output_data);
497 
498   Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
499                               const DeviceMemory<int8> &weights,
500                               const DeviceMemory<float> &weight_scales,
501                               const dnn::BatchDescriptor &input_dimensions,
502                               const dnn::BatchDescriptor &output_dimensions,
503                               DeviceMemory<float> *output_data);
504 
505   Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
506                               const DeviceMemory<int16> &weights,
507                               const DeviceMemory<float> &weight_scales,
508                               const dnn::BatchDescriptor &input_dimensions,
509                               const dnn::BatchDescriptor &output_dimensions,
510                               DeviceMemory<float> *output_data);
511 
512   Stream &ThenBiasAdd(const DeviceMemory<float> &input_data,
513                       const DeviceMemory<float> &biases,
514                       const dnn::BatchDescriptor &dimensions,
515                       DeviceMemory<float> *output_data);
516 
517   Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
518                           const dnn::BatchDescriptor &input_dimensions,
519                           const DeviceMemory<double> &input_data,
520                           const dnn::BatchDescriptor &output_dimensions,
521                           DeviceMemory<double> *output_data,
522                           ScratchAllocator *workspace_allocator = nullptr);
523 
524   Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
525                           const dnn::BatchDescriptor &input_dimensions,
526                           const DeviceMemory<float> &input_data,
527                           const dnn::BatchDescriptor &output_dimensions,
528                           DeviceMemory<float> *output_data,
529                           ScratchAllocator *workspace_allocator = nullptr);
530 
531   Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
532                           const dnn::BatchDescriptor &input_dimensions,
533                           const DeviceMemory<Eigen::half> &input_data,
534                           const dnn::BatchDescriptor &output_dimensions,
535                           DeviceMemory<Eigen::half> *output_data,
536                           ScratchAllocator *workspace_allocator = nullptr);
537 
538   Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
539                           const dnn::BatchDescriptor &input_dimensions,
540                           const DeviceMemory<int8> &input_data,
541                           const dnn::BatchDescriptor &output_dimensions,
542                           DeviceMemory<int8> *output_data,
543                           ScratchAllocator *workspace_allocator = nullptr);
544 
545   Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
546                            const dnn::BatchDescriptor &input_dimensions,
547                            const DeviceMemory<double> &input_data,
548                            const dnn::BatchDescriptor &output_dimensions,
549                            const DeviceMemory<double> &output_data,
550                            const DeviceMemory<double> &input_diff_data,
551                            DeviceMemory<double> *output_diff_data,
552                            ScratchAllocator *workspace_allocator = nullptr);
553 
554   Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
555                            const dnn::BatchDescriptor &input_dimensions,
556                            const DeviceMemory<float> &input_data,
557                            const dnn::BatchDescriptor &output_dimensions,
558                            const DeviceMemory<float> &output_data,
559                            const DeviceMemory<float> &input_diff_data,
560                            DeviceMemory<float> *output_diff_data,
561                            ScratchAllocator *workspace_allocator = nullptr);
562 
563   Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
564                            const dnn::BatchDescriptor &input_dimensions,
565                            const DeviceMemory<Eigen::half> &input_data,
566                            const dnn::BatchDescriptor &output_dimensions,
567                            const DeviceMemory<Eigen::half> &output_data,
568                            const DeviceMemory<Eigen::half> &input_diff_data,
569                            DeviceMemory<Eigen::half> *output_diff_data,
570                            ScratchAllocator *workspace_allocator = nullptr);
571 
572   Stream &ThenNormalizeWithDimensions(
573       const dnn::NormalizeDescriptor &normalize_descriptor,
574       const dnn::BatchDescriptor &dimensions,
575       const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data);
576 
577   Stream &ThenNormalizeBackwardWithDimensions(
578       const dnn::NormalizeDescriptor &normalize_descriptor,
579       const dnn::BatchDescriptor &dimensions,
580       const DeviceMemory<float> &raw_data,
581       const DeviceMemory<float> &normalized_data,
582       const DeviceMemory<float> &normalized_variable_gradient,
583       DeviceMemory<float> *raw_variable_gradient,
584       ScratchAllocator *workspace_allocator = nullptr);
585 
586   Stream &ThenActivate(dnn::ActivationMode activation_mode,
587                        const dnn::BatchDescriptor &dimensions,
588                        const DeviceMemory<float> &input_data,
589                        DeviceMemory<float> *output_data);
590 
591   // Same as ThenActivate, but also takes an options argument that can be used
592   // for platform-specific option flags.
593   Stream &ThenActivateWithOptions(dnn::ActivationMode activation_mode,
594                                   const dnn::BatchDescriptor &dimensions,
595                                   const DeviceMemory<float> &input_data,
596                                   DeviceMemory<float> *output_data,
597                                   uint64 options);
598 
599   Stream &ThenDepthConcatenate(
600       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
601       port::ArraySlice<const DeviceMemory<float> *> input_data,
602       DeviceMemory<float> *output_data);
603 
604   Stream &ThenSpaceConcatenate(
605       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
606       port::ArraySlice<const DeviceMemory<float> *> input_data,
607       DeviceMemory<float> *output_data,
608       dnn::SpaceConcatenateMode concat_direction);
609 
610   // Change the layout of the data by shrinking one dimension (or set of
611   // dimensions) and growing another dimension (or set of dimensions), while
612   // keeping the total number of data elements constant, and maintaining the
613   // current data ordering.
614   Stream &ThenReshape(const dnn::BatchDescriptor &input_dimensions,
615                       const DeviceMemory<float> &input_data,
616                       const dnn::BatchDescriptor &output_dimensions,
617                       DeviceMemory<float> *output_data);
618 
  // Depth to space takes an X by Y image with depth D*M² and changes it to an
  // MX x MY image with depth D. Each input location (x,y) with depth D*M² in
  // the input image is changed to an MxM contiguous area in the output image,
  // with the values being laid out in raster order specified by
  // DepthToSpaceLayout, and will have a new depth of D.
  // See the DoDepthToSpace comment for more information.
  //
  // NOTE(review): sqrt_depth_reduction is presumably the factor M described
  // above -- confirm against DoDepthToSpace.
  Stream &ThenDepthToSpace(const dnn::BatchDescriptor &input_dimensions,
                           const DeviceMemory<float> &input_data,
                           const dnn::DepthToSpaceLayout &depth_to_space_layout,
                           const int sqrt_depth_reduction,
                           DeviceMemory<float> *output_data);
630 
  // Space to depth is the inverse of depth to space. Space to depth takes each
  // non-overlapping M by M patch (in the X and Y dimensions) with depth D of
  // the input, and transforms it to a 1 by 1 patch with depth D*M². If the
  // input has size (MX, MY, D), the output has size (X, Y, D*M²). The number of
  // data elements is not changed.
  //
  // The layout argument uses the same dnn::DepthToSpaceLayout type as
  // ThenDepthToSpace. NOTE(review): sqrt_depth_increase is presumably the
  // factor M described above -- confirm.
  Stream &ThenSpaceToDepth(const dnn::BatchDescriptor &input_dimensions,
                           const DeviceMemory<float> &input_data,
                           const dnn::DepthToSpaceLayout &space_to_depth_layout,
                           const int sqrt_depth_increase,
                           DeviceMemory<float> *output_data);
641 
  // NOTE(review): undocumented here; presumably applies `operation`
  // element-wise across the given inputs and writes the result to
  // *output_data -- confirm against DnnSupport. input_dimensions and
  // input_data are separate slices, presumably one descriptor per input
  // buffer; output_dimensions describes the output tensor.
  Stream &ThenElementwiseOperate(
      dnn::ElementwiseOperation operation,
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float> *> input_data,
      const dnn::BatchDescriptor &output_dimensions,
      DeviceMemory<float> *output_data);
648 
  // Takes the same arguments as ThenElementwiseOperate plus a slice of
  // integer input_multiplicands and an integer output_divisor.
  // NOTE(review): presumably each input is scaled by its multiplicand and the
  // result is divided by output_divisor -- confirm against DnnSupport.
  Stream &ThenElementwiseOperateScaledQuantized(
      dnn::ElementwiseOperation operation,
      port::ArraySlice<int> input_multiplicands, int output_divisor,
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float> *> input_data,
      const dnn::BatchDescriptor &output_dimensions,
      DeviceMemory<float> *output_data);
656 
  // NOTE(review): undocumented here; presumably pads the input in the X and Y
  // dimensions by the given left/right/top/bottom amounts and writes the
  // result to *output_data -- confirm against DnnSupport. `dimensions`
  // presumably describes the input tensor.
  Stream &ThenXYPad(const dnn::BatchDescriptor &dimensions,
                    const DeviceMemory<float> &input_data, int64 left_pad,
                    int64 right_pad, int64 top_pad, int64 bottom_pad,
                    DeviceMemory<float> *output_data);
661 
  // NOTE(review): undocumented here; presumably trims the input in the X and
  // Y dimensions by the given left/right/top/bottom amounts and writes the
  // result to *output_data -- confirm against DnnSupport. `dimensions`
  // presumably describes the input tensor.
  Stream &ThenXYSlice(const dnn::BatchDescriptor &dimensions,
                      const DeviceMemory<float> &input_data, int64 left_trim,
                      int64 right_trim, int64 top_trim, int64 bottom_trim,
                      DeviceMemory<float> *output_data);
666 
  // Grows the input tensor by replicating the X and Y dimensions. The batch and
  // depth/feature_map dimensions are unchanged. Currently, the input tensor is
  // limited to X=1 and Y=1.
  //
  // NOTE(review): replicate_x / replicate_y are presumably the replication
  // counts for the X and Y dimensions respectively -- confirm.
  Stream &ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
                          const DeviceMemory<float> &input_data,
                          int64 replicate_x, int64 replicate_y,
                          DeviceMemory<float> *output_data);
674 
  // See DnnSupport::DoMemcpyD2HQuantized.
  //
  // Untyped overload: the host destination is a raw pointer plus `size`. The
  // typed template overload of this method passes `size` as a byte count
  // (element count * sizeof(ElementType)).
  Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
                                 dnn::QuantizedActivationMode mode,
                                 void *host_dst, uint64 size);
679 
680   // Template version of ThenMemcpyD2HQuantized that takes a MutableArraySlice
681   // and uses the Quantization trait to call the generic version of
682   // ThenMemcpyD2HQuantized with the correct QuantizedActivationMode.
683   template <typename ElementType>
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,port::MutableArraySlice<ElementType> host_dst)684   Stream &ThenMemcpyD2HQuantized(
685       const DeviceMemory<float> &gpu_unquantized_src,
686       port::MutableArraySlice<ElementType> host_dst) {
687     return ThenMemcpyD2HQuantized(
688         gpu_unquantized_src, Quantization<ElementType>::kModeId,
689         host_dst.data(), host_dst.size() * sizeof(ElementType));
690   }
691 
  // See DnnSupport::DoMemcpyH2DQuantized.
  //
  // Untyped overload: the host source is a raw pointer plus `size`. The typed
  // template overload of this method passes `size` as a byte count (element
  // count * sizeof(ElementType)).
  Stream &ThenMemcpyH2DQuantized(const void *host_src, uint64 size,
                                 dnn::QuantizedActivationMode mode,
                                 DeviceMemory<float> *gpu_unquantized_dst);
696 
697   // Template version of ThenMemcpyH2DQuantized that takes an ArraySlice
698   // and uses the Quantization trait to call the generic version of
699   // ThenMemcpyH2DQuantized with the correct QuantizedActivationMode.
700   template <typename ElementType>
ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src,DeviceMemory<float> * gpu_unquantized_dst)701   Stream &ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src,
702                                  DeviceMemory<float> *gpu_unquantized_dst) {
703     return ThenMemcpyH2DQuantized(
704         host_src.data(), host_src.size() * sizeof(ElementType),
705         Quantization<ElementType>::kModeId, gpu_unquantized_dst);
706   }
707 
  // See DnnSupport::DoCopyHostBuffer2Device.
  //
  // Takes a HostBuffer as the source and a float DeviceMemory as the
  // destination.
  Stream &ThenCopyHostBuffer2Device(HostBuffer *buffer_src,
                                    DeviceMemory<float> *gpu_unquantized_dst);
711 
  // See DnnSupport::DoCopyDevice2HostBuffer.
  //
  // Takes a float DeviceMemory as the source and a HostBuffer as the
  // destination.
  Stream &ThenCopyDevice2HostBuffer(
      const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst);
715 
716   /////////////////
717   // BLAS support
718 
  // See BlasSupport::DoBlasAsum.
  //
  // Overloaded on the element type of x: float, double, std::complex<float>
  // and std::complex<double>. The complex overloads write a real-valued
  // result (float / double respectively).
  Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
                       int incx, DeviceMemory<float> *result);
  Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
                       int incx, DeviceMemory<double> *result);
  Stream &ThenBlasAsum(uint64 elem_count,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       DeviceMemory<float> *result);
  Stream &ThenBlasAsum(uint64 elem_count,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       DeviceMemory<double> *result);
730 
  // See BlasSupport::DoBlasAxpy. Note that, even for the case where alpha is
  // present in DeviceMemory, it must be an execution-time constant (i.e. a
  // value that the stream does not change or populate during the course of
  // execution). The value is effectively captured at stream-enqueue time.
  //
  // Overloaded for float, double, std::complex<float> and
  // std::complex<double>; in each overload shown here alpha is passed by
  // value with the same element type as x and y.
  Stream &ThenBlasAxpy(uint64 elem_count, float alpha,
                       const DeviceMemory<float> &x, int incx,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasAxpy(uint64 elem_count, double alpha,
                       const DeviceMemory<double> &x, int incx,
                       DeviceMemory<double> *y, int incy);
  Stream &ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       DeviceMemory<std::complex<double>> *y, int incy);
748 
  // See BlasSupport::DoBlasCopy.
  //
  // Overloaded for float, double, std::complex<float> and
  // std::complex<double>; x is the const source and y the mutable
  // destination, each with its own increment.
  Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
                       int incx, DeviceMemory<float> *y, int incy);
  Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
                       int incx, DeviceMemory<double> *y, int incy);
  Stream &ThenBlasCopy(uint64 elem_count,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasCopy(uint64 elem_count,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       DeviceMemory<std::complex<double>> *y, int incy);
760 
  // See BlasSupport::DoBlasDot.
  //
  // Real-valued overloads only (float and double); the complex variants are
  // declared separately as ThenBlasDotc / ThenBlasDotu.
  Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x, int incx,
                      const DeviceMemory<float> &y, int incy,
                      DeviceMemory<float> *result);
  Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
                      int incx, const DeviceMemory<double> &y, int incy,
                      DeviceMemory<double> *result);
768 
  // See BlasSupport::DoBlasDotc.
  //
  // Complex-only overloads (std::complex<float> and std::complex<double>);
  // the result has the same complex type as the inputs.
  Stream &ThenBlasDotc(uint64 elem_count,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *result);
  Stream &ThenBlasDotc(uint64 elem_count,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *result);
778 
  // See BlasSupport::DoBlasDotu.
  //
  // Complex-only overloads (std::complex<float> and std::complex<double>);
  // the result has the same complex type as the inputs.
  Stream &ThenBlasDotu(uint64 elem_count,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *result);
  Stream &ThenBlasDotu(uint64 elem_count,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *result);
788 
  // See BlasSupport::DoBlasNrm2.
  //
  // Overloaded on the element type of x: float, double, std::complex<float>
  // and std::complex<double>. The complex overloads write a real-valued
  // result (float / double respectively).
  Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
                       int incx, DeviceMemory<float> *result);
  Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
                       int incx, DeviceMemory<double> *result);
  Stream &ThenBlasNrm2(uint64 elem_count,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       DeviceMemory<float> *result);
  Stream &ThenBlasNrm2(uint64 elem_count,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       DeviceMemory<double> *result);
800 
  // See BlasSupport::DoBlasRot.
  //
  // Overloaded for float, double, std::complex<float> and
  // std::complex<double> vectors. x and y are both passed as mutable
  // pointers; c and s are passed by value and are real (float / double) even
  // in the complex overloads.
  Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
                      DeviceMemory<float> *y, int incy, float c, float s);
  Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x, int incx,
                      DeviceMemory<double> *y, int incy, double c, double s);
  Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
                      int incx, DeviceMemory<std::complex<float>> *y, int incy,
                      float c, float s);
  Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
                      int incx, DeviceMemory<std::complex<double>> *y, int incy,
                      double c, double s);
812 
  // See BlasSupport::DoBlasRotg.
  //
  // Overloaded for float, double, std::complex<float> and
  // std::complex<double>. All four arguments are mutable device pointers; in
  // the complex overloads c is real (float / double) while a, b and s are
  // complex.
  Stream &ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
                       DeviceMemory<float> *c, DeviceMemory<float> *s);
  Stream &ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
                       DeviceMemory<double> *c, DeviceMemory<double> *s);
  Stream &ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
                       DeviceMemory<std::complex<float>> *b,
                       DeviceMemory<float> *c,
                       DeviceMemory<std::complex<float>> *s);
  Stream &ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
                       DeviceMemory<std::complex<double>> *b,
                       DeviceMemory<double> *c,
                       DeviceMemory<std::complex<double>> *s);
826 
  // See BlasSupport::DoBlasRotm.
  //
  // Real-valued overloads only (float and double). x and y are mutable
  // pointers; param is read-only device memory.
  Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x, int incx,
                       DeviceMemory<float> *y, int incy,
                       const DeviceMemory<float> &param);
  Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x, int incx,
                       DeviceMemory<double> *y, int incy,
                       const DeviceMemory<double> &param);
834 
  // See BlasSupport::DoBlasRotmg.
  //
  // Real-valued overloads only (float and double). d1, d2, x1 and param are
  // mutable pointers; y1 is the only read-only argument.
  Stream &ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
                        DeviceMemory<float> *x1, const DeviceMemory<float> &y1,
                        DeviceMemory<float> *param);
  Stream &ThenBlasRotmg(DeviceMemory<double> *d1, DeviceMemory<double> *d2,
                        DeviceMemory<double> *x1,
                        const DeviceMemory<double> &y1,
                        DeviceMemory<double> *param);
843 
  // See BlasSupport::DoBlasScal.
  //
  // Six overloads: a real alpha with a real x (float, double), a real alpha
  // with a complex x, and a complex alpha with a complex x. alpha is passed
  // by value and x is a mutable pointer.
  Stream &ThenBlasScal(uint64 elem_count, float alpha, DeviceMemory<float> *x,
                       int incx);
  Stream &ThenBlasScal(uint64 elem_count, double alpha, DeviceMemory<double> *x,
                       int incx);
  Stream &ThenBlasScal(uint64 elem_count, float alpha,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasScal(uint64 elem_count, double alpha,
                       DeviceMemory<std::complex<double>> *x, int incx);
  Stream &ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
                       DeviceMemory<std::complex<double>> *x, int incx);
857 
  // See BlasSupport::DoBlasSwap.
  //
  // Overloaded for float, double, std::complex<float> and
  // std::complex<double>; both x and y are mutable pointers.
  Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x, int incx,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x, int incx,
                       DeviceMemory<double> *y, int incy);
  Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
                       int incx, DeviceMemory<std::complex<float>> *y,
                       int incy);
  Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
                       int incx, DeviceMemory<std::complex<double>> *y,
                       int incy);
869 
  // See BlasSupport::DoBlasIamax.
  //
  // Overloaded on the element type of x (float, double, std::complex<float>,
  // std::complex<double>); every overload writes its result to a
  // DeviceMemory<int>.
  Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
                        int incx, DeviceMemory<int> *result);
  Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
                        int incx, DeviceMemory<int> *result);
  Stream &ThenBlasIamax(uint64 elem_count,
                        const DeviceMemory<std::complex<float>> &x, int incx,
                        DeviceMemory<int> *result);
  Stream &ThenBlasIamax(uint64 elem_count,
                        const DeviceMemory<std::complex<double>> &x, int incx,
                        DeviceMemory<int> *result);
881 
  // See BlasSupport::DoBlasIamin.
  //
  // Overloaded on the element type of x (float, double, std::complex<float>,
  // std::complex<double>); every overload writes its result to a
  // DeviceMemory<int>.
  Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
                        int incx, DeviceMemory<int> *result);
  Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
                        int incx, DeviceMemory<int> *result);
  Stream &ThenBlasIamin(uint64 elem_count,
                        const DeviceMemory<std::complex<float>> &x, int incx,
                        DeviceMemory<int> *result);
  Stream &ThenBlasIamin(uint64 elem_count,
                        const DeviceMemory<std::complex<double>> &x, int incx,
                        DeviceMemory<int> *result);
893 
  // See BlasSupport::DoBlasGbmv.
  //
  // Overloaded for float, double, std::complex<float> and
  // std::complex<double>. alpha and beta are passed by value; a and x are
  // read-only and y is a mutable pointer.
  Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
                       uint64 ku, float alpha, const DeviceMemory<float> &a,
                       int lda, const DeviceMemory<float> &x, int incx,
                       float beta, DeviceMemory<float> *y, int incy);
  Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
                       uint64 ku, double alpha, const DeviceMemory<double> &a,
                       int lda, const DeviceMemory<double> &x, int incx,
                       double beta, DeviceMemory<double> *y, int incy);
  Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
                       uint64 ku, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
                       uint64 ku, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);
915 
  // See BlasSupport::DoBlasGemv.
  //
  // Overloaded for float, double, std::complex<float> and
  // std::complex<double>. alpha and beta are passed by value; a and x are
  // read-only and y is a mutable pointer.
  Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, float alpha,
                       const DeviceMemory<float> &a, int lda,
                       const DeviceMemory<float> &x, int incx, float beta,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, double alpha,
                       const DeviceMemory<double> &a, int lda,
                       const DeviceMemory<double> &x, int incx, double beta,
                       DeviceMemory<double> *y, int incy);
  Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);
937 
  // Variants of ThenBlasGemv that take the same arguments followed by a
  // trailing blas::ProfileResult pointer. Overloaded for float, double,
  // std::complex<float> and std::complex<double>.
  // NOTE(review): output_profile_result is presumably filled in with
  // profiling data for the call -- confirm against BlasSupport.
  Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
                                    float alpha, const DeviceMemory<float> &a,
                                    int lda, const DeviceMemory<float> &x,
                                    int incx, float beta,
                                    DeviceMemory<float> *y, int incy,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
                                    double alpha, const DeviceMemory<double> &a,
                                    int lda, const DeviceMemory<double> &x,
                                    int incx, double beta,
                                    DeviceMemory<double> *y, int incy,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemvWithProfiling(
      blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
      const DeviceMemory<std::complex<float>> &a, int lda,
      const DeviceMemory<std::complex<float>> &x, int incx,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemvWithProfiling(
      blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
      const DeviceMemory<std::complex<double>> &a, int lda,
      const DeviceMemory<std::complex<double>> &x, int incx,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
      int incy, blas::ProfileResult *output_profile_result);
962 
  // See BlasSupport::DoBlasGer.
  //
  // Real-valued overloads only (float and double); the complex variants are
  // declared separately as ThenBlasGerc / ThenBlasGeru. x and y are
  // read-only; a is a mutable pointer.
  Stream &ThenBlasGer(uint64 m, uint64 n, float alpha,
                      const DeviceMemory<float> &x, int incx,
                      const DeviceMemory<float> &y, int incy,
                      DeviceMemory<float> *a, int lda);
  Stream &ThenBlasGer(uint64 m, uint64 n, double alpha,
                      const DeviceMemory<double> &x, int incx,
                      const DeviceMemory<double> &y, int incy,
                      DeviceMemory<double> *a, int lda);
972 
  // See BlasSupport::DoBlasGerc.
  //
  // Complex-only overloads (std::complex<float> and std::complex<double>).
  // x and y are read-only; a is a mutable pointer.
  Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *a, int lda);
  Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *a, int lda);
982 
  // See BlasSupport::DoBlasGeru.
  //
  // Complex-only overloads (std::complex<float> and std::complex<double>).
  // x and y are read-only; a is a mutable pointer.
  Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *a, int lda);
  Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *a, int lda);
992 
  // See BlasSupport::DoBlasHbmv.
  //
  // Complex-only overloads (std::complex<float> and std::complex<double>).
  // alpha and beta are passed by value; a and x are read-only and y is a
  // mutable pointer.
  Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);
1006 
  // See BlasSupport::DoBlasHemv.
  //
  // Complex-only overloads (std::complex<float> and std::complex<double>).
  // alpha and beta are passed by value; a and x are read-only and y is a
  // mutable pointer.
  Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);
1020 
  // See BlasSupport::DoBlasHer.
  //
  // Two overloads, for std::complex<float> and std::complex<double> data.
  // Note that alpha is real (float / double) while x and a are complex.
  Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
                      const DeviceMemory<std::complex<float>> &x, int incx,
                      DeviceMemory<std::complex<float>> *a, int lda);
  Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
                      const DeviceMemory<std::complex<double>> &x, int incx,
                      DeviceMemory<std::complex<double>> *a, int lda);
1028 
  // See BlasSupport::DoBlasHer2.
  //
  // Complex-only overloads (std::complex<float> and std::complex<double>);
  // unlike ThenBlasHer, alpha here is complex. x and y are read-only; a is a
  // mutable pointer.
  Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *a, int lda);
  Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *a, int lda);
1040 
  // See BlasSupport::DoBlasHpmv.
  //
  // Complex-only overloads (std::complex<float> and std::complex<double>).
  // The matrix argument `ap` is taken without a leading-dimension parameter.
  Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &ap,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &ap,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);
1054 
  // See BlasSupport::DoBlasHpr.
  //
  // Two overloads, for std::complex<float> and std::complex<double> data.
  // alpha is real (float / double); `ap` is a mutable pointer taken without a
  // leading-dimension parameter.
  Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
                      const DeviceMemory<std::complex<float>> &x, int incx,
                      DeviceMemory<std::complex<float>> *ap);
  Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
                      const DeviceMemory<std::complex<double>> &x, int incx,
                      DeviceMemory<std::complex<double>> *ap);
1062 
  // See BlasSupport::DoBlasHpr2.
  //
  // Complex-only overloads (std::complex<float> and std::complex<double>);
  // alpha is complex. `ap` is a mutable pointer taken without a
  // leading-dimension parameter.
  Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       const DeviceMemory<std::complex<float>> &y, int incy,
                       DeviceMemory<std::complex<float>> *ap);
  Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       const DeviceMemory<std::complex<double>> &y, int incy,
                       DeviceMemory<std::complex<double>> *ap);
1074 
  // See BlasSupport::DoBlasSbmv.
  //
  // Real-valued overloads only (float and double). alpha and beta are passed
  // by value; a and x are read-only and y is a mutable pointer.
  Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, float alpha,
                       const DeviceMemory<float> &a, int lda,
                       const DeviceMemory<float> &x, int incx, float beta,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, double alpha,
                       const DeviceMemory<double> &a, int lda,
                       const DeviceMemory<double> &x, int incx, double beta,
                       DeviceMemory<double> *y, int incy);
1084 
  // See BlasSupport::DoBlasSpmv.
  //
  // Real-valued overloads only (float and double). The matrix argument `ap`
  // is taken without a leading-dimension parameter.
  Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
                       const DeviceMemory<float> &ap,
                       const DeviceMemory<float> &x, int incx, float beta,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
                       const DeviceMemory<double> &ap,
                       const DeviceMemory<double> &x, int incx, double beta,
                       DeviceMemory<double> *y, int incy);
1094 
  // See BlasSupport::DoBlasSpr.
  //
  // Real-valued overloads only (float and double). x is read-only; `ap` is a
  // mutable pointer taken without a leading-dimension parameter.
  Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
                      const DeviceMemory<float> &x, int incx,
                      DeviceMemory<float> *ap);
  Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
                      const DeviceMemory<double> &x, int incx,
                      DeviceMemory<double> *ap);
1102 
  // See BlasSupport::DoBlasSpr2.
  //
  // Real-valued overloads only (float and double). x and y are read-only;
  // `ap` is a mutable pointer taken without a leading-dimension parameter.
  Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
                       const DeviceMemory<float> &x, int incx,
                       const DeviceMemory<float> &y, int incy,
                       DeviceMemory<float> *ap);
  Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
                       const DeviceMemory<double> &x, int incx,
                       const DeviceMemory<double> &y, int incy,
                       DeviceMemory<double> *ap);
1112 
  // See BlasSupport::DoBlasSymv.
  //
  // Real-valued overloads only (float and double). alpha and beta are passed
  // by value; a and x are read-only and y is a mutable pointer.
  Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
                       const DeviceMemory<float> &a, int lda,
                       const DeviceMemory<float> &x, int incx, float beta,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
                       const DeviceMemory<double> &a, int lda,
                       const DeviceMemory<double> &x, int incx, double beta,
                       DeviceMemory<double> *y, int incy);
1122 
  // See BlasSupport::DoBlasSyr.
  //
  // Real-valued overloads only (float and double). x is read-only; a is a
  // mutable pointer with leading dimension lda.
  Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
                      const DeviceMemory<float> &x, int incx,
                      DeviceMemory<float> *a, int lda);
  Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
                      const DeviceMemory<double> &x, int incx,
                      DeviceMemory<double> *a, int lda);
1130 
  // See BlasSupport::DoBlasSyr2.
  //
  // Real-valued overloads only (float and double). x and y are read-only; a
  // is a mutable pointer with leading dimension lda.
  Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
                       const DeviceMemory<float> &x, int incx,
                       const DeviceMemory<float> &y, int incy,
                       DeviceMemory<float> *a, int lda);
  Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
                       const DeviceMemory<double> &x, int incx,
                       const DeviceMemory<double> &y, int incy,
                       DeviceMemory<double> *a, int lda);
1140 
  // See BlasSupport::DoBlasTbmv.
  //
  // Overloaded for float, double, std::complex<float> and
  // std::complex<double>. a is read-only; x is passed as a mutable pointer
  // and there is no separate output argument.
  Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<float> &a, int lda,
                       DeviceMemory<float> *x, int incx);
  Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<double> &a, int lda,
                       DeviceMemory<double> *x, int incx);
  Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
                       blas::Diagonal diag, uint64 n, uint64 k,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *x, int incx);
1158 
1159   // See BlasSupport::DoBlasTbsv.
       // Overloads: float, double, std::complex<float>, std::complex<double>.
       // Same parameter list as ThenBlasTbmv: `a` by const reference, `x` by
       // mutable pointer, with uint64 n and k.
       // NOTE(review): reference BLAS xTBSV solves op(A)*x = b for a triangular
       // band matrix A, overwriting x (which holds b on entry) with the
       // solution; confirm against BlasSupport::DoBlasTbsv.
1160   Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1161                        blas::Diagonal diag, uint64 n, uint64 k,
1162                        const DeviceMemory<float> &a, int lda,
1163                        DeviceMemory<float> *x, int incx);
1164   Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1165                        blas::Diagonal diag, uint64 n, uint64 k,
1166                        const DeviceMemory<double> &a, int lda,
1167                        DeviceMemory<double> *x, int incx);
1168   Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1169                        blas::Diagonal diag, uint64 n, uint64 k,
1170                        const DeviceMemory<std::complex<float>> &a, int lda,
1171                        DeviceMemory<std::complex<float>> *x, int incx);
1172   Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
1173                        blas::Diagonal diag, uint64 n, uint64 k,
1174                        const DeviceMemory<std::complex<double>> &a, int lda,
1175                        DeviceMemory<std::complex<double>> *x, int incx);
1176 
1177   // See BlasSupport::DoBlasTpmv.
       // Overloads: float, double, std::complex<float>, std::complex<double>.
       // The matrix argument is named `ap` and has no leading-dimension
       // parameter; `x` is the only operand taken by mutable pointer.
       // NOTE(review): reference BLAS xTPMV computes x = op(A)*x for a triangular
       // matrix A supplied in packed storage; confirm against
       // BlasSupport::DoBlasTpmv.
1178   Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1179                        blas::Diagonal diag, uint64 n,
1180                        const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1181                        int incx);
1182   Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1183                        blas::Diagonal diag, uint64 n,
1184                        const DeviceMemory<double> &ap, DeviceMemory<double> *x,
1185                        int incx);
1186   Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1187                        blas::Diagonal diag, uint64 n,
1188                        const DeviceMemory<std::complex<float>> &ap,
1189                        DeviceMemory<std::complex<float>> *x, int incx);
1190   Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
1191                        blas::Diagonal diag, uint64 n,
1192                        const DeviceMemory<std::complex<double>> &ap,
1193                        DeviceMemory<std::complex<double>> *x, int incx);
1194 
1195   // See BlasSupport::DoBlasTpsv.
       // Overloads: float, double, std::complex<float>, std::complex<double>.
       // Same parameter list as ThenBlasTpmv: matrix `ap` by const reference with
       // no leading-dimension parameter, `x` by mutable pointer.
       // NOTE(review): reference BLAS xTPSV solves op(A)*x = b for a triangular
       // matrix A in packed storage, overwriting x with the solution; confirm
       // against BlasSupport::DoBlasTpsv.
1196   Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1197                        blas::Diagonal diag, uint64 n,
1198                        const DeviceMemory<float> &ap, DeviceMemory<float> *x,
1199                        int incx);
1200   Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1201                        blas::Diagonal diag, uint64 n,
1202                        const DeviceMemory<double> &ap, DeviceMemory<double> *x,
1203                        int incx);
1204   Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1205                        blas::Diagonal diag, uint64 n,
1206                        const DeviceMemory<std::complex<float>> &ap,
1207                        DeviceMemory<std::complex<float>> *x, int incx);
1208   Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
1209                        blas::Diagonal diag, uint64 n,
1210                        const DeviceMemory<std::complex<double>> &ap,
1211                        DeviceMemory<std::complex<double>> *x, int incx);
1212 
1213   // See BlasSupport::DoBlasTrmv.
       // Overloads: float, double, std::complex<float>, std::complex<double>.
       // `a` (with leading dimension lda) is taken by const reference; `x` is the
       // only operand taken by mutable pointer.
       // NOTE(review): reference BLAS xTRMV computes x = op(A)*x for a triangular
       // n-by-n matrix A; confirm against BlasSupport::DoBlasTrmv.
1214   Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1215                        blas::Diagonal diag, uint64 n,
1216                        const DeviceMemory<float> &a, int lda,
1217                        DeviceMemory<float> *x, int incx);
1218   Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1219                        blas::Diagonal diag, uint64 n,
1220                        const DeviceMemory<double> &a, int lda,
1221                        DeviceMemory<double> *x, int incx);
1222   Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1223                        blas::Diagonal diag, uint64 n,
1224                        const DeviceMemory<std::complex<float>> &a, int lda,
1225                        DeviceMemory<std::complex<float>> *x, int incx);
1226   Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
1227                        blas::Diagonal diag, uint64 n,
1228                        const DeviceMemory<std::complex<double>> &a, int lda,
1229                        DeviceMemory<std::complex<double>> *x, int incx);
1230 
1231   // See BlasSupport::DoBlasTrsv.
       // Overloads: float, double, std::complex<float>, std::complex<double>.
       // Same parameter list as ThenBlasTrmv: `a` by const reference, `x` by
       // mutable pointer.
       // NOTE(review): reference BLAS xTRSV solves op(A)*x = b for a triangular
       // n-by-n matrix A, overwriting x with the solution; confirm against
       // BlasSupport::DoBlasTrsv.
1232   Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1233                        blas::Diagonal diag, uint64 n,
1234                        const DeviceMemory<float> &a, int lda,
1235                        DeviceMemory<float> *x, int incx);
1236   Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1237                        blas::Diagonal diag, uint64 n,
1238                        const DeviceMemory<double> &a, int lda,
1239                        DeviceMemory<double> *x, int incx);
1240   Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1241                        blas::Diagonal diag, uint64 n,
1242                        const DeviceMemory<std::complex<float>> &a, int lda,
1243                        DeviceMemory<std::complex<float>> *x, int incx);
1244   Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
1245                        blas::Diagonal diag, uint64 n,
1246                        const DeviceMemory<std::complex<double>> &a, int lda,
1247                        DeviceMemory<std::complex<double>> *x, int incx);
1248 
1249   // See BlasSupport::DoBlasGemm.
       // Five overloads, all marked TF_EXPORT, selected by element type:
       // Eigen::half, float, double, std::complex<float>, std::complex<double>.
       // alpha and beta are passed by value and have the element type, except in
       // the Eigen::half overload where they are float. `a` and `b` are taken by
       // const reference; `c` is the only operand taken by mutable pointer.
       // NOTE(review): reference BLAS xGEMM computes
       // C = alpha*op(A)*op(B) + beta*C with op(A) m-by-k, op(B) k-by-n and
       // C m-by-n; confirm against BlasSupport::DoBlasGemm.
1250   TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1251                                  uint64 m, uint64 n, uint64 k, float alpha,
1252                                  const DeviceMemory<Eigen::half> &a, int lda,
1253                                  const DeviceMemory<Eigen::half> &b, int ldb,
1254                                  float beta, DeviceMemory<Eigen::half> *c,
1255                                  int ldc);
1256   TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1257                                  uint64 m, uint64 n, uint64 k, float alpha,
1258                                  const DeviceMemory<float> &a, int lda,
1259                                  const DeviceMemory<float> &b, int ldb,
1260                                  float beta, DeviceMemory<float> *c, int ldc);
1261   TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1262                                  uint64 m, uint64 n, uint64 k, double alpha,
1263                                  const DeviceMemory<double> &a, int lda,
1264                                  const DeviceMemory<double> &b, int ldb,
1265                                  double beta, DeviceMemory<double> *c, int ldc);
1266   TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1267                                  uint64 m, uint64 n, uint64 k,
1268                                  std::complex<float> alpha,
1269                                  const DeviceMemory<std::complex<float>> &a,
1270                                  int lda,
1271                                  const DeviceMemory<std::complex<float>> &b,
1272                                  int ldb, std::complex<float> beta,
1273                                  DeviceMemory<std::complex<float>> *c, int ldc);
1274   TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
1275                                  uint64 m, uint64 n, uint64 k,
1276                                  std::complex<double> alpha,
1277                                  const DeviceMemory<std::complex<double>> &a,
1278                                  int lda,
1279                                  const DeviceMemory<std::complex<double>> &b,
1280                                  int ldb, std::complex<double> beta,
1281                                  DeviceMemory<std::complex<double>> *c,
1282                                  int ldc);
1283 
1284   Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
1285                                     blas::Transpose transb, uint64 m, uint64 n,
1286                                     uint64 k, float alpha,
1287                                     const DeviceMemory<Eigen::half> &a, int lda,
1288                                     const DeviceMemory<Eigen::half> &b, int ldb,
1289                                     float beta, DeviceMemory<Eigen::half> *c,
1290                                     int ldc,
1291                                     blas::ProfileResult *output_profile_result);
       // ThenBlasGemmWithProfiling overload set: Eigen::half (above), then float,
       // double, std::complex<float> and std::complex<double> (below). Each
       // overload has the same parameter list as the ThenBlasGemm overload for
       // the same element type, followed by a trailing
       // blas::ProfileResult *output_profile_result. Unlike ThenBlasGemm, these
       // declarations are not marked TF_EXPORT.
       // NOTE(review): there is no "See BlasSupport::..." comment for this set;
       // presumably it forwards to a profiling GEMM entry point on BlasSupport
       // and fills *output_profile_result. Whether the pointer may be null is
       // not visible in this header -- confirm in the implementation.
1292   Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
1293                                     blas::Transpose transb, uint64 m, uint64 n,
1294                                     uint64 k, float alpha,
1295                                     const DeviceMemory<float> &a, int lda,
1296                                     const DeviceMemory<float> &b, int ldb,
1297                                     float beta, DeviceMemory<float> *c, int ldc,
1298                                     blas::ProfileResult *output_profile_result);
1299   Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
1300                                     blas::Transpose transb, uint64 m, uint64 n,
1301                                     uint64 k, double alpha,
1302                                     const DeviceMemory<double> &a, int lda,
1303                                     const DeviceMemory<double> &b, int ldb,
1304                                     double beta, DeviceMemory<double> *c,
1305                                     int ldc,
1306                                     blas::ProfileResult *output_profile_result);
1307   Stream &ThenBlasGemmWithProfiling(
1308       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1309       uint64 k, std::complex<float> alpha,
1310       const DeviceMemory<std::complex<float>> &a, int lda,
1311       const DeviceMemory<std::complex<float>> &b, int ldb,
1312       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1313       blas::ProfileResult *output_profile_result);
1314   Stream &ThenBlasGemmWithProfiling(
1315       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1316       uint64 k, std::complex<double> alpha,
1317       const DeviceMemory<std::complex<double>> &a, int lda,
1318       const DeviceMemory<std::complex<double>> &b, int ldb,
1319       std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1320       blas::ProfileResult *output_profile_result);
1321 
1322   // See BlasSupport::DoBlasGemmWithAlgorithm.
       // Six overloads, selected by element type:
       //   - Eigen::half, float, double, std::complex<float>,
       //     std::complex<double>: a, b, c, alpha and beta all use that type;
       //   - int8 inputs (a, b) with int scalars (alpha, beta) and int output (c).
       // Unlike ThenBlasGemm, alpha and beta are HostOrDeviceScalar<T> passed by
       // const reference, and every overload additionally takes a
       // blas::ComputationType, a blas::AlgorithmType and a trailing
       // blas::ProfileResult *output_profile_result.
       // NOTE(review): HostOrDeviceScalar presumably allows the scalar to live in
       // either host or device memory -- confirm in host_or_device_scalar.h.
       // Whether output_profile_result may be null is not visible here.
1323   Stream &ThenBlasGemmWithAlgorithm(
1324       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1325       uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
1326       const DeviceMemory<Eigen::half> &a, int lda,
1327       const DeviceMemory<Eigen::half> &b, int ldb,
1328       const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
1329       int ldc, blas::ComputationType computation_type,
1330       blas::AlgorithmType algorithm,
1331       blas::ProfileResult *output_profile_result);
1332   Stream &ThenBlasGemmWithAlgorithm(
1333       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1334       uint64 k, const HostOrDeviceScalar<int> &alpha,
1335       const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b,
1336       int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c,
1337       int ldc, blas::ComputationType computation_type,
1338       blas::AlgorithmType algorithm,
1339       blas::ProfileResult *output_profile_result);
1340   Stream &ThenBlasGemmWithAlgorithm(
1341       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1342       uint64 k, const HostOrDeviceScalar<float> &alpha,
1343       const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
1344       int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
1345       int ldc, blas::ComputationType computation_type,
1346       blas::AlgorithmType algorithm,
1347       blas::ProfileResult *output_profile_result);
1348   Stream &ThenBlasGemmWithAlgorithm(
1349       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1350       uint64 k, const HostOrDeviceScalar<double> &alpha,
1351       const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
1352       int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
1353       int ldc, blas::ComputationType computation_type,
1354       blas::AlgorithmType algorithm,
1355       blas::ProfileResult *output_profile_result);
1356   Stream &ThenBlasGemmWithAlgorithm(
1357       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1358       uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
1359       const DeviceMemory<std::complex<float>> &a, int lda,
1360       const DeviceMemory<std::complex<float>> &b, int ldb,
1361       const HostOrDeviceScalar<std::complex<float>> &beta,
1362       DeviceMemory<std::complex<float>> *c, int ldc,
1363       blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1364       blas::ProfileResult *output_profile_result);
1365   Stream &ThenBlasGemmWithAlgorithm(
1366       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1367       uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
1368       const DeviceMemory<std::complex<double>> &a, int lda,
1369       const DeviceMemory<std::complex<double>> &b, int ldb,
1370       const HostOrDeviceScalar<std::complex<double>> &beta,
1371       DeviceMemory<std::complex<double>> *c, int ldc,
1372       blas::ComputationType computation_type, blas::AlgorithmType algorithm,
1373       blas::ProfileResult *output_profile_result);
1374 
1375   // See BlasSupport::DoBlasGemmBatched.
       // ThenBlasGemmBatched: five overloads -- Eigen::half (with float
       // alpha/beta), float, double, std::complex<float>, std::complex<double>.
       // a, b and c are each a port::ArraySlice of DeviceMemory<T> pointers taken
       // by const reference, and batch_count is passed separately as an int.
       // NOTE(review): presumably a[i], b[i], c[i] describe the i-th GEMM and
       // batch_count must agree with the slice lengths -- confirm in
       // BlasSupport::DoBlasGemmBatched.
1376   Stream &ThenBlasGemmBatched(
1377       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1378       uint64 k, float alpha,
1379       const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
1380       const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
1381       float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
1382       int ldc, int batch_count);
1383   Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
1384                               uint64 m, uint64 n, uint64 k, float alpha,
1385                               const port::ArraySlice<DeviceMemory<float> *> &a,
1386                               int lda,
1387                               const port::ArraySlice<DeviceMemory<float> *> &b,
1388                               int ldb, float beta,
1389                               const port::ArraySlice<DeviceMemory<float> *> &c,
1390                               int ldc, int batch_count);
1391   Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
1392                               uint64 m, uint64 n, uint64 k, double alpha,
1393                               const port::ArraySlice<DeviceMemory<double> *> &a,
1394                               int lda,
1395                               const port::ArraySlice<DeviceMemory<double> *> &b,
1396                               int ldb, double beta,
1397                               const port::ArraySlice<DeviceMemory<double> *> &c,
1398                               int ldc, int batch_count);
1399   Stream &ThenBlasGemmBatched(
1400       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1401       uint64 k, std::complex<float> alpha,
1402       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
1403       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
1404       std::complex<float> beta,
1405       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
1406       int batch_count);
1407   Stream &ThenBlasGemmBatched(
1408       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1409       uint64 k, std::complex<double> alpha,
1410       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
1411       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
1412       std::complex<double> beta,
1413       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
1414       int batch_count);
       // ThenBlasGemmBatchedWithScratch: same five element types and parameter
       // lists as ThenBlasGemmBatched, followed by a trailing
       // ScratchAllocator *scratch_allocator.
       // NOTE(review): how the allocator is used, and whether it may be null, is
       // not visible in this header -- check the implementation.
1415   Stream &ThenBlasGemmBatchedWithScratch(
1416       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1417       uint64 k, float alpha,
1418       const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
1419       const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
1420       float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
1421       int ldc, int batch_count, ScratchAllocator *scratch_allocator);
1422   Stream &ThenBlasGemmBatchedWithScratch(
1423       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1424       uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
1425       int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
1426       float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
1427       int batch_count, ScratchAllocator *scratch_allocator);
1428   Stream &ThenBlasGemmBatchedWithScratch(
1429       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1430       uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
1431       int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
1432       double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
1433       int batch_count, ScratchAllocator *scratch_allocator);
1434   Stream &ThenBlasGemmBatchedWithScratch(
1435       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1436       uint64 k, std::complex<float> alpha,
1437       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
1438       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
1439       std::complex<float> beta,
1440       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
1441       int batch_count, ScratchAllocator *scratch_allocator);
1442   Stream &ThenBlasGemmBatchedWithScratch(
1443       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1444       uint64 k, std::complex<double> alpha,
1445       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
1446       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
1447       std::complex<double> beta,
1448       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
1449       int batch_count, ScratchAllocator *scratch_allocator);
       // ThenBlasGemmStridedBatched: same five element types, but each operand is
       // a single DeviceMemory<T> accompanied by an int64 stride (stride_a,
       // stride_b, stride_c) instead of a slice of per-matrix pointers. `a` and
       // `b` are taken by const reference; `c` by mutable pointer.
       // NOTE(review): presumably matrix i of each operand starts i*stride
       // elements into its buffer, as in cuBLAS gemmStridedBatched -- confirm in
       // the BlasSupport counterpart (no "See ..." comment is given for it).
1450   Stream &ThenBlasGemmStridedBatched(
1451       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1452       uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
1453       int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
1454       int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
1455       int64 stride_c, int batch_count);
1456   Stream &ThenBlasGemmStridedBatched(
1457       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1458       uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
1459       int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
1460       float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
1461       int batch_count);
1462   Stream &ThenBlasGemmStridedBatched(
1463       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1464       uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
1465       int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
1466       double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
1467       int batch_count);
1468   Stream &ThenBlasGemmStridedBatched(
1469       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1470       uint64 k, std::complex<float> alpha,
1471       const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
1472       const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
1473       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1474       int64 stride_c, int batch_count);
1475   Stream &ThenBlasGemmStridedBatched(
1476       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
1477       uint64 k, std::complex<double> alpha,
1478       const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
1479       const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
1480       std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1481       int64 stride_c, int batch_count);
1482 
1483   // See BlasSupport::DoBlasHemm.
       // Overloads: std::complex<float> and std::complex<double> only. alpha and
       // beta are complex and passed by value; `a` and `b` are taken by const
       // reference; `c` is the only operand taken by mutable pointer.
       // NOTE(review): reference BLAS xHEMM computes C = alpha*A*B + beta*C or
       // C = alpha*B*A + beta*C (chosen by `side`) with A Hermitian; confirm
       // against BlasSupport::DoBlasHemm.
1484   Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
1485                        uint64 n, std::complex<float> alpha,
1486                        const DeviceMemory<std::complex<float>> &a, int lda,
1487                        const DeviceMemory<std::complex<float>> &b, int ldb,
1488                        std::complex<float> beta,
1489                        DeviceMemory<std::complex<float>> *c, int ldc);
1490   Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
1491                        uint64 n, std::complex<double> alpha,
1492                        const DeviceMemory<std::complex<double>> &a, int lda,
1493                        const DeviceMemory<std::complex<double>> &b, int ldb,
1494                        std::complex<double> beta,
1495                        DeviceMemory<std::complex<double>> *c, int ldc);
1496 
1497   // See BlasSupport::DoBlasHerk.
       // Overloads: std::complex<float> and std::complex<double> matrices. Note
       // that alpha and beta are real here (float / double), not complex. `a` is
       // taken by const reference; `c` is the only operand taken by mutable
       // pointer.
       // NOTE(review): reference BLAS xHERK is the Hermitian rank-k update
       // C = alpha*A*A^H + beta*C or C = alpha*A^H*A + beta*C (chosen by
       // `trans`); confirm against BlasSupport::DoBlasHerk.
1498   Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1499                        uint64 k, float alpha,
1500                        const DeviceMemory<std::complex<float>> &a, int lda,
1501                        float beta, DeviceMemory<std::complex<float>> *c,
1502                        int ldc);
1503   Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1504                        uint64 k, double alpha,
1505                        const DeviceMemory<std::complex<double>> &a, int lda,
1506                        double beta, DeviceMemory<std::complex<double>> *c,
1507                        int ldc);
1508 
1509   // See BlasSupport::DoBlasHer2k.
       // Overloads: std::complex<float> and std::complex<double> matrices. alpha
       // is complex but beta is real (float / double). `a` and `b` are taken by
       // const reference; `c` is the only operand taken by mutable pointer.
       // NOTE(review): reference BLAS xHER2K is the Hermitian rank-2k update
       // C = alpha*A*B^H + conj(alpha)*B*A^H + beta*C (or its conjugate-
       // transposed form, chosen by `trans`); confirm against
       // BlasSupport::DoBlasHer2k.
1510   Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1511                         uint64 k, std::complex<float> alpha,
1512                         const DeviceMemory<std::complex<float>> &a, int lda,
1513                         const DeviceMemory<std::complex<float>> &b, int ldb,
1514                         float beta, DeviceMemory<std::complex<float>> *c,
1515                         int ldc);
1516   Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1517                         uint64 k, std::complex<double> alpha,
1518                         const DeviceMemory<std::complex<double>> &a, int lda,
1519                         const DeviceMemory<std::complex<double>> &b, int ldb,
1520                         double beta, DeviceMemory<std::complex<double>> *c,
1521                         int ldc);
1522 
1523   // See BlasSupport::DoBlasSymm.
       // Overloads: float, double, std::complex<float>, std::complex<double>.
       // alpha and beta have the element type and are passed by value; `a` and
       // `b` are taken by const reference; `c` is the only operand taken by
       // mutable pointer.
       // NOTE(review): reference BLAS xSYMM computes C = alpha*A*B + beta*C or
       // C = alpha*B*A + beta*C (chosen by `side`) with A symmetric; confirm
       // against BlasSupport::DoBlasSymm.
1524   Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
1525                        uint64 n, float alpha, const DeviceMemory<float> &a,
1526                        int lda, const DeviceMemory<float> &b, int ldb,
1527                        float beta, DeviceMemory<float> *c, int ldc);
1528   Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
1529                        uint64 n, double alpha, const DeviceMemory<double> &a,
1530                        int lda, const DeviceMemory<double> &b, int ldb,
1531                        double beta, DeviceMemory<double> *c, int ldc);
1532   Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
1533                        uint64 n, std::complex<float> alpha,
1534                        const DeviceMemory<std::complex<float>> &a, int lda,
1535                        const DeviceMemory<std::complex<float>> &b, int ldb,
1536                        std::complex<float> beta,
1537                        DeviceMemory<std::complex<float>> *c, int ldc);
1538   Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
1539                        uint64 n, std::complex<double> alpha,
1540                        const DeviceMemory<std::complex<double>> &a, int lda,
1541                        const DeviceMemory<std::complex<double>> &b, int ldb,
1542                        std::complex<double> beta,
1543                        DeviceMemory<std::complex<double>> *c, int ldc);
1544 
1545   // See BlasSupport::DoBlasSyrk.
       // Overloads: float, double, std::complex<float>, std::complex<double>.
       // alpha and beta have the element type and are passed by value; `a` is
       // taken by const reference; `c` is the only operand taken by mutable
       // pointer.
       // NOTE(review): reference BLAS xSYRK is the symmetric rank-k update
       // C = alpha*A*A^T + beta*C or C = alpha*A^T*A + beta*C (chosen by
       // `trans`); confirm against BlasSupport::DoBlasSyrk.
1546   Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1547                        uint64 k, float alpha, const DeviceMemory<float> &a,
1548                        int lda, float beta, DeviceMemory<float> *c, int ldc);
1549   Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1550                        uint64 k, double alpha, const DeviceMemory<double> &a,
1551                        int lda, double beta, DeviceMemory<double> *c, int ldc);
1552   Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1553                        uint64 k, std::complex<float> alpha,
1554                        const DeviceMemory<std::complex<float>> &a, int lda,
1555                        std::complex<float> beta,
1556                        DeviceMemory<std::complex<float>> *c, int ldc);
1557   Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1558                        uint64 k, std::complex<double> alpha,
1559                        const DeviceMemory<std::complex<double>> &a, int lda,
1560                        std::complex<double> beta,
1561                        DeviceMemory<std::complex<double>> *c, int ldc);
1562 
1563   // See BlasSupport::DoBlasSyr2k.
       // Overloads: float, double, std::complex<float>, std::complex<double>.
       // alpha and beta have the element type and are passed by value; `a` and
       // `b` are taken by const reference; `c` is the only operand taken by
       // mutable pointer.
       // NOTE(review): reference BLAS xSYR2K is the symmetric rank-2k update
       // C = alpha*A*B^T + alpha*B*A^T + beta*C (or its transposed form, chosen
       // by `trans`); confirm against BlasSupport::DoBlasSyr2k.
1564   Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1565                         uint64 k, float alpha, const DeviceMemory<float> &a,
1566                         int lda, const DeviceMemory<float> &b, int ldb,
1567                         float beta, DeviceMemory<float> *c, int ldc);
1568   Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1569                         uint64 k, double alpha, const DeviceMemory<double> &a,
1570                         int lda, const DeviceMemory<double> &b, int ldb,
1571                         double beta, DeviceMemory<double> *c, int ldc);
1572   Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1573                         uint64 k, std::complex<float> alpha,
1574                         const DeviceMemory<std::complex<float>> &a, int lda,
1575                         const DeviceMemory<std::complex<float>> &b, int ldb,
1576                         std::complex<float> beta,
1577                         DeviceMemory<std::complex<float>> *c, int ldc);
1578   Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
1579                         uint64 k, std::complex<double> alpha,
1580                         const DeviceMemory<std::complex<double>> &a, int lda,
1581                         const DeviceMemory<std::complex<double>> &b, int ldb,
1582                         std::complex<double> beta,
1583                         DeviceMemory<std::complex<double>> *c, int ldc);
1584 
1585   // See BlasSupport::DoBlasTrmm.
       // Overloads: float, double, std::complex<float>, std::complex<double>.
       // There is no beta and no separate output matrix: `a` is taken by const
       // reference and `b` is the only operand taken by mutable pointer.
       // NOTE(review): reference BLAS xTRMM computes B = alpha*op(A)*B or
       // B = alpha*B*op(A) (chosen by `side`) with A triangular, in place in B;
       // confirm against BlasSupport::DoBlasTrmm.
1586   Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
1587                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1588                        uint64 n, float alpha, const DeviceMemory<float> &a,
1589                        int lda, DeviceMemory<float> *b, int ldb);
1590   Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
1591                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1592                        uint64 n, double alpha, const DeviceMemory<double> &a,
1593                        int lda, DeviceMemory<double> *b, int ldb);
1594   Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
1595                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1596                        uint64 n, std::complex<float> alpha,
1597                        const DeviceMemory<std::complex<float>> &a, int lda,
1598                        DeviceMemory<std::complex<float>> *b, int ldb);
1599   Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
1600                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1601                        uint64 n, std::complex<double> alpha,
1602                        const DeviceMemory<std::complex<double>> &a, int lda,
1603                        DeviceMemory<std::complex<double>> *b, int ldb);
1604 
1605   // See BlasSupport::DoBlasTrsm.
       // Overloads: float, double, std::complex<float>, std::complex<double>.
       // Same parameter list as ThenBlasTrmm: `a` by const reference, `b` as the
       // only operand taken by mutable pointer, no beta.
       // NOTE(review): reference BLAS xTRSM solves op(A)*X = alpha*B or
       // X*op(A) = alpha*B (chosen by `side`) with A triangular, overwriting B
       // with X; confirm against BlasSupport::DoBlasTrsm.
1606   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1607                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1608                        uint64 n, float alpha, const DeviceMemory<float> &a,
1609                        int lda, DeviceMemory<float> *b, int ldb);
1610   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1611                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1612                        uint64 n, double alpha, const DeviceMemory<double> &a,
1613                        int lda, DeviceMemory<double> *b, int ldb);
1614   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1615                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1616                        uint64 n, std::complex<float> alpha,
1617                        const DeviceMemory<std::complex<float>> &a, int lda,
1618                        DeviceMemory<std::complex<float>> *b, int ldb);
1619   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1620                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
1621                        uint64 n, std::complex<double> alpha,
1622                        const DeviceMemory<std::complex<double>> &a, int lda,
1623                        DeviceMemory<std::complex<double>> *b, int ldb);
1624 
1625   // See FftSupport::DoFft.
       // Six overloads, distinguished by the input/output element types:
       //   - complex -> complex: std::complex<float>, std::complex<double>;
       //   - real -> complex:    float -> std::complex<float>,
       //                         double -> std::complex<double>;
       //   - complex -> real:    std::complex<float> -> float,
       //                         std::complex<double> -> double.
       // `input` is taken by const reference and `output` by mutable pointer;
       // `plan` is a non-const fft::Plan pointer.
       // NOTE(review): transform size and direction are presumably carried by
       // `plan`, since no other parameters are passed -- confirm in fft.h.
1626   Stream &ThenFft(fft::Plan *plan,
1627                   const DeviceMemory<std::complex<float>> &input,
1628                   DeviceMemory<std::complex<float>> *output);
1629   Stream &ThenFft(fft::Plan *plan,
1630                   const DeviceMemory<std::complex<double>> &input,
1631                   DeviceMemory<std::complex<double>> *output);
1632   Stream &ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
1633                   DeviceMemory<std::complex<float>> *output);
1634   Stream &ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
1635                   DeviceMemory<std::complex<double>> *output);
1636   Stream &ThenFft(fft::Plan *plan,
1637                   const DeviceMemory<std::complex<float>> &input,
1638                   DeviceMemory<float> *output);
1639   Stream &ThenFft(fft::Plan *plan,
1640                   const DeviceMemory<std::complex<double>> &input,
1641                   DeviceMemory<double> *output);
1642 
1643   // Makes the RNG use the provided value as the basis for further generation.
1644   // /dev/urandom (good) and /dev/random (better, but sometimes slow) are good
1645   // sources of seed data if the default (high quality) sources are not
1646   // desired.
1647   // For most use cases, this function will not be necessary; each provided
1648   // back-end implementation will be appropriately seeded by default.
1649   // At a minimum 16 bytes of data are required in the seed buffer.
1650   //
1651   // To seed with good (non-reproducible) data:
1652   //   File* f = File::Open("/dev/random", "r");
1653   //   int64 bytes_read = f->Read(seed_data, bytes_to_read);
1654   //   < error checking >
1655   //   stream.ThenSetRngSeed(seed_data, bytes_read);
1656   //
1657   // To seed with reproducible data:
1658   //   uint64_t seed_data[2] = { <data> };
1659   //   stream.ThenSetRngSeed(seed_data, 16);
1660   Stream &ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes);
1661 
1662   // Populates the memory indicated by values with uniform-random-distribution
1663   // values. TODO(leary) seeding API/description
1664   //
1665   // Uses the type and size of the DeviceMemory to infer what data should be
1666   // populated.
  // Overloads are declared for float, double, std::complex<float> and
  // std::complex<double> element types; `values` is a mutable pointer to the
  // device memory to fill.
  Stream &ThenPopulateRandUniform(DeviceMemory<float> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<double> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<float>> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<double>> *values);
  // NOTE(review): presumably the Gaussian counterpart of
  // ThenPopulateRandUniform, filling `values` using the given `mean` and
  // `stddev` -- the implementation is not visible in this header; confirm.
  // Only float and double overloads are declared.
  Stream &ThenPopulateRandGaussian(float mean, float stddev,
                                   DeviceMemory<float> *values);
  Stream &ThenPopulateRandGaussian(double mean, double stddev,
                                   DeviceMemory<double> *values);
1675 
1676   // Entrain onto the stream: a memcpy to a host destination from a GPU source
1677   // of the given target size. host_dst must be a pointer to host memory
1678   // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
1679   // then registered with StreamExecutor::HostMemoryRegister.
  // `size` is a byte count: the ThenMemcpyD2H slice helper forwards
  // element-count * sizeof(T) to this overload.
  Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
                     uint64 size);
1682 
1683   // Entrain onto the stream: a memcpy to a GPU destination from a host source
1684   // of the given target size. host_src must be a pointer to host memory
1685   // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
1686   // then registered with StreamExecutor::HostMemoryRegister.
  // `size` is a byte count: the ThenMemcpyH2D slice helper forwards
  // element-count * sizeof(T) to this overload.
  Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
                     uint64 size);
1689 
  // Alternative interface for memcpying from device to host that takes an
  // array slice. Checks that the destination host slice can accommodate the
  // device source size.
1693   template <typename T>
ThenMemcpyD2H(const DeviceMemory<T> & gpu_src,port::MutableArraySlice<T> host_dst)1694   Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src,
1695                         port::MutableArraySlice<T> host_dst) {
1696     auto host_size = host_dst.size() * sizeof(T);
1697     CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size());
1698     return ThenMemcpy(host_dst.begin(), gpu_src, host_size);
1699   }
1700 
1701   // Alternative interface for memcpying from host to device that takes an
1702   // array slice. Checks that the destination size can accommodate the host
1703   // slice size.
1704   template <typename T>
ThenMemcpyH2D(port::ArraySlice<T> host_src,DeviceMemory<T> * gpu_dst)1705   Stream &ThenMemcpyH2D(port::ArraySlice<T> host_src,
1706                         DeviceMemory<T> *gpu_dst) {
1707     auto host_size = host_src.size() * sizeof(T);
1708     CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size);
1709     return ThenMemcpy(gpu_dst, host_src.begin(), host_size);
1710   }
1711 
1712   // Entrain onto the stream: a memcpy to a GPU destination from a GPU source
1713   // of the given target size. gpu_src/dst must be pointers to GPU memory and
1714   // peer access must be enabled between their owning StreamExecutors.
  // ThenMemcpyD2D forwards its arguments to this overload unchanged.
  // NOTE(review): `size` is presumably a byte count, as for the host
  // overloads -- confirm against the implementation.
  Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
                     uint64 size);
1717 
1718   // Calls to the device-to-device copy overload of ThenMemcpy -- useful for
1719   // ensuring that the host pointer isn't getting confused accidentally with a
1720   // device pointer if you're not doing metaprogramming against the API.
ThenMemcpyD2D(DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)1721   Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst,
1722                         const DeviceMemoryBase &gpu_src, uint64 size) {
1723     return ThenMemcpy(gpu_dst, gpu_src, size);
1724   }
1725 
1726   // Entrain onto the stream: a memset of zero at a GPU location of size bytes.
1727   // The location must not be null.
1728   Stream &ThenMemZero(DeviceMemoryBase *location, uint64 size);
1729 
1730   // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of
1731   // size bytes, where bytes must be evenly 32-bit sized (i.e. evenly divisible
1732   // by 4). The location must not be null.
1733   Stream &ThenMemset32(DeviceMemoryBase *location, uint32 pattern, uint64 size);
1734 
1735   // Enqueue a forward operation of the RNN model onto the stream.
1736   // See DnnSupport::DoRnnForward for more details.
  // Eigen::half overload: every typed DeviceMemory argument uses Eigen::half
  // elements. The input_* tensors and `params` are taken by const reference;
  // the output_* tensors are passed by mutable pointer.
  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<Eigen::half> &input_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<Eigen::half> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<Eigen::half> &input_c_data,
                         const DeviceMemory<Eigen::half> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<Eigen::half> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<Eigen::half> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<Eigen::half> *output_c_data,
                         bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);
1755 
  // float overload of ThenRnnForward: same parameter list as the Eigen::half
  // overload, with DeviceMemory<float> tensors.
  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<float> &input_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<float> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<float> &input_c_data,
                         const DeviceMemory<float> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<float> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<float> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<float> *output_c_data, bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);
1773 
  // double overload of ThenRnnForward: same parameter list as the Eigen::half
  // overload, with DeviceMemory<double> tensors.
  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<double> &input_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<double> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<double> &input_c_data,
                         const DeviceMemory<double> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<double> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<double> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<double> *output_c_data, bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);
1791 
1792   // Enqueue a backward operation of the RNN model onto the stream.
1793   // See DnnSupport::DoRnnBackward for more details.
  // Eigen::half overload. The forward inputs, forward outputs and
  // output_*_backprop_data tensors are taken by const reference; the
  // input_*_backprop_data and params_backprop_data tensors are passed by
  // mutable pointer. `reserve_space_data` is DeviceMemory<uint8> regardless of
  // the element type.
  Stream &ThenRnnBackward(
      const dnn::RnnDescriptor &rnn_desc,
      const dnn::RnnSequenceTensorDescriptor &input_desc,
      const DeviceMemory<Eigen::half> &input_data,
      const dnn::RnnStateTensorDescriptor &input_h_desc,
      const DeviceMemory<Eigen::half> &input_h_data,
      const dnn::RnnStateTensorDescriptor &input_c_desc,
      const DeviceMemory<Eigen::half> &input_c_data,
      const DeviceMemory<Eigen::half> &params,
      const dnn::RnnSequenceTensorDescriptor &output_desc,
      const DeviceMemory<Eigen::half> &output_data,
      const dnn::RnnStateTensorDescriptor &output_h_desc,
      const DeviceMemory<Eigen::half> &output_h_data,
      const dnn::RnnStateTensorDescriptor &output_c_desc,
      const DeviceMemory<Eigen::half> &output_c_data,
      const DeviceMemory<Eigen::half> &output_backprop_data,
      const DeviceMemory<Eigen::half> &output_h_backprop_data,
      const DeviceMemory<Eigen::half> &output_c_backprop_data,
      DeviceMemory<Eigen::half> *input_backprop_data,
      DeviceMemory<Eigen::half> *input_h_backprop_data,
      DeviceMemory<Eigen::half> *input_c_backprop_data,
      DeviceMemory<Eigen::half> *params_backprop_data,
      DeviceMemory<uint8> *reserve_space_data,
      ScratchAllocator *workspace_allocator,
      dnn::ProfileResult *output_profile_result);
1819 
  // float overload of ThenRnnBackward: same parameter list as the Eigen::half
  // overload, with DeviceMemory<float> tensors (`reserve_space_data` stays
  // DeviceMemory<uint8>).
  Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
                          const dnn::RnnSequenceTensorDescriptor &input_desc,
                          const DeviceMemory<float> &input_data,
                          const dnn::RnnStateTensorDescriptor &input_h_desc,
                          const DeviceMemory<float> &input_h_data,
                          const dnn::RnnStateTensorDescriptor &input_c_desc,
                          const DeviceMemory<float> &input_c_data,
                          const DeviceMemory<float> &params,
                          const dnn::RnnSequenceTensorDescriptor &output_desc,
                          const DeviceMemory<float> &output_data,
                          const dnn::RnnStateTensorDescriptor &output_h_desc,
                          const DeviceMemory<float> &output_h_data,
                          const dnn::RnnStateTensorDescriptor &output_c_desc,
                          const DeviceMemory<float> &output_c_data,
                          const DeviceMemory<float> &output_backprop_data,
                          const DeviceMemory<float> &output_h_backprop_data,
                          const DeviceMemory<float> &output_c_backprop_data,
                          DeviceMemory<float> *input_backprop_data,
                          DeviceMemory<float> *input_h_backprop_data,
                          DeviceMemory<float> *input_c_backprop_data,
                          DeviceMemory<float> *params_backprop_data,
                          DeviceMemory<uint8> *reserve_space_data,
                          ScratchAllocator *workspace_allocator,
                          dnn::ProfileResult *output_profile_result);
1844 
  // double overload of ThenRnnBackward: same parameter list as the Eigen::half
  // overload, with DeviceMemory<double> tensors (`reserve_space_data` stays
  // DeviceMemory<uint8>).
  Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
                          const dnn::RnnSequenceTensorDescriptor &input_desc,
                          const DeviceMemory<double> &input_data,
                          const dnn::RnnStateTensorDescriptor &input_h_desc,
                          const DeviceMemory<double> &input_h_data,
                          const dnn::RnnStateTensorDescriptor &input_c_desc,
                          const DeviceMemory<double> &input_c_data,
                          const DeviceMemory<double> &params,
                          const dnn::RnnSequenceTensorDescriptor &output_desc,
                          const DeviceMemory<double> &output_data,
                          const dnn::RnnStateTensorDescriptor &output_h_desc,
                          const DeviceMemory<double> &output_h_data,
                          const dnn::RnnStateTensorDescriptor &output_c_desc,
                          const DeviceMemory<double> &output_c_data,
                          const DeviceMemory<double> &output_backprop_data,
                          const DeviceMemory<double> &output_h_backprop_data,
                          const DeviceMemory<double> &output_c_backprop_data,
                          DeviceMemory<double> *input_backprop_data,
                          DeviceMemory<double> *input_h_backprop_data,
                          DeviceMemory<double> *input_c_backprop_data,
                          DeviceMemory<double> *params_backprop_data,
                          DeviceMemory<uint8> *reserve_space_data,
                          ScratchAllocator *workspace_allocator,
                          dnn::ProfileResult *output_profile_result);
1869 
  // Enqueue onto the stream an operation that transforms a tensor.
1871   // See DnnSupport::DoTransformTensor for more details.
  // `input_data` and `output_data` are untyped DeviceMemoryBase buffers, so
  // their element types are supplied explicitly through `input_type` and
  // `output_type`. `scale` has no default value in this overload; its meaning
  // is defined by DnnSupport::DoTransformTensor.
  Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                              dnn::DataType input_type,
                              const DeviceMemoryBase &input_data,
                              const dnn::BatchDescriptor &output_desc,
                              dnn::DataType output_type, float scale,
                              DeviceMemoryBase *output_data);
1878 
1879   // The templated version of the above ThenTransformTensor. Useful when the
1880   // input and output types are statically known.
1881   template <typename InElemT, typename OutElemT>
ThenTransformTensor(const dnn::BatchDescriptor & input_desc,const DeviceMemory<InElemT> & input_data,const dnn::BatchDescriptor & output_desc,DeviceMemory<OutElemT> * output_data)1882   Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
1883                               const DeviceMemory<InElemT> &input_data,
1884                               const dnn::BatchDescriptor &output_desc,
1885                               DeviceMemory<OutElemT> *output_data) {
1886     return ThenTransformTensor(input_desc, dnn::ToDataType<InElemT>(),
1887                                input_data, output_desc,
1888                                dnn::ToDataType<OutElemT>(), output_data);
1889   }
1890 
1891   // (Synchronously) block the host code waiting for the operations
1892   // entrained on the stream (enqueued to this point in program
1893   // execution) to complete.
1894   //
1895   // Returns an OK status if the blocking was successful and the stream is ok().
1896   // Otherwise returns an error describing why the blocking failed.
  // Annotated LOCKS_EXCLUDED(mu_): callers must not hold mu_ when calling.
  port::Status BlockHostUntilDone() LOCKS_EXCLUDED(mu_);
1898 
  // Warning! This method interacts with internal threads in
  // sometimes-unpredictable ways and is intended for GPU-Executor-internal
  // use only. Please check with a member of the FASTR team before making use
  // of this method.
1904   //
1905   // Entrains onto the stream a function to be executed on the host at some
1906   // point in the future.
1907   // Async host callbacks DO NOT block the stream as device functions (or as
1908   // synchronous host callbacks). No synchronization is possible with
1909   // asynchronous callbacks; they are strictly fire-and-forget.
  // This method is intended for internal use only, due to the potential for
  // undefined behavior with synchronization using OpenCL user events.
1912   // The ONLY lifetime guarantee in these calls is that the StreamExecutor
1913   // parameter will still be valid - this Stream may not be!
1914   // Any callbacks requiring device API calls must use this method.
1915   Stream &ThenEnqueueOnBackgroundThread(
1916       std::function<void(StreamExecutor *)> task);
1917 
1918   // Returns the (opaque) platform-specific backing object. Ownership is not
1919   // transferred to the caller.
implementation()1920   internal::StreamInterface *implementation() { return implementation_.get(); }
1921 
1922   // Entrains onto the stream a callback to the host (from the device).
1923   // Behaves as ThenDoHostCallbackWithStatus below, but the callback should
1924   // never fail or its failure is inconsequential.
1925   //
1926   // This is kept for backward compatibility. Future code should use
1927   // ThenDoHostCallbackWithStatus and explicitly return a success status.
1928   // TODO(b/112125301): Eventually remove this method.
1929   Stream &ThenDoHostCallback(std::function<void()> callback);
1930 
1931   // Entrains onto the stream a callback to the host (from the device).
1932   // Host callbacks block/occupy the stream just as device functions
1933   // (execute one at a time, block later stream operations).
1934   // Whether the callback return status affects the result of BlockHostUntilDone
1935   // is platform-dependent.
1936   //
1937   // Behavior is undefined when synchronizing using OpenCL user events.
1938   // Behavior is undefined if host callbacks call device routines or insert
1939   // them into any stream.
1940   //
1941   // On certain platforms, ThenDoHostCallback is expected to have significant
1942   // negative effects on performance.
1943   Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback);
1944 
1945   // Returns the StreamExecutor (parent object) associated with this stream.
parent()1946   StreamExecutor *parent() const {
1947     CHECK(parent_ != nullptr);
1948     return parent_;
1949   }
1950 
1951   // Returns the (internal usage) temporary-memory-allocation manager associated
1952   // with this stream.
1953   internal::TemporaryMemoryManager *temporary_memory_manager();
1954 
1955   // Returns a debugging string "[stream=0x...,impl=0x...]".
1956   string DebugStreamPointers() const;
1957 
1958  private:
1959   friend class host::HostBlas;  // for parent_.
1960   friend class host::HostFft;   // for parent_.
1961   friend class host::HostRng;   // for parent_.
1962   template <typename... Args>
1963   friend struct ThenBlasImpl;  // for implementing ThenBlasXXX.
1964   friend class ocl::CLBlas;    // for parent_.
1965 
InErrorState()1966   bool InErrorState() const LOCKS_EXCLUDED(mu_) {
1967     tf_shared_lock lock(mu_);
1968     return !ok_;
1969   }
1970 
1971   // Sets the error state if operation_retcode is false.
1972   // This is a useful shorthand for many stream routines.
CheckError(bool operation_retcode)1973   void CheckError(bool operation_retcode) LOCKS_EXCLUDED(mu_) {
1974     if (operation_retcode) {
1975       return;
1976     }
1977     mutex_lock lock(mu_);
1978     ok_ = false;
1979   }
1980 
1981   // Checks the status and logs the error message, if any.
1982   void CheckStatus(port::Status status) LOCKS_EXCLUDED(mu_);
1983 
SetError()1984   void SetError() { CheckError(false /* = operation_retcode */); }
1985 
SetErrorAndLogNoDnnSupport()1986   void SetErrorAndLogNoDnnSupport() {
1987     SetError();
1988     LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
1989                     "without DNN support";
1990   }
1991 
1992   // The StreamExecutor that supports the operation of this stream.
1993   StreamExecutor *parent_;
1994 
1995   // The platform-dependent implementation that the StreamExecutor interface
1996   // delegates to.
1997   std::unique_ptr<internal::StreamInterface> implementation_;
1998 
1999   // mutex that guards the allocation / error state flags.
2000   // Mutable so that it can be obtained via const reader lock.
2001   mutable mutex mu_;
2002 
  // Whether Init() was successfully called to allocate this stream on the
  // underlying platform. It simply flips from false to true with a sanity
  // check. See StreamExecutor::AllocateStream.
2006   bool allocated_ GUARDED_BY(mu_);
2007 
2008   // Whether all operations have entrained successfully to the current program
2009   // point.
2010   bool ok_ GUARDED_BY(mu_);
2011 
2012   // Sub-streams that are generated from this stream. Each element has a pointer
2013   // to sub-stream and a boolean value indicating if this substream is ready to
2014   // be reused.
2015   std::vector<std::pair<std::unique_ptr<Stream>, bool>> sub_streams_
2016       GUARDED_BY(mu_);
2017 
2018   // Streams can allocate temporary memories to help with work they enqueue
2019   // (e.g. for scratch memory spaces). This member tracks those allocations and
2020   // notes when they can be reclaimed -- reclamation is attempted when
2021   // BlockHostUntilDone() is called.
2022   internal::TemporaryMemoryManager temporary_memory_manager_;
2023 
2024   // Implementation of ThenConvolveBackwardBias that is shared by all types.
2025   template <typename T>
2026   Stream &ThenConvolveBackwardBiasImpl(
2027       const dnn::BatchDescriptor &input_descriptor,
2028       const DeviceMemory<T> &input_data,
2029       const dnn::BatchDescriptor &bias_descriptor,
2030       DeviceMemory<T> *backward_bias_data);
2031 
2032   SE_DISALLOW_COPY_AND_ASSIGN(Stream);
2033 };
2034 
2035 ////////////
2036 // Inlines
2037 
2038 template <typename T>
2039 inline port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
AllocateTemporaryArray(uint64 element_count)2040 Stream::AllocateTemporaryArray(uint64 element_count) {
2041   return temporary_memory_manager_.AllocateArray<T>(element_count);
2042 }
2043 
temporary_memory_manager()2044 inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() {
2045   return &temporary_memory_manager_;
2046 }
2047 
2048 template <>
2049 struct Quantization<uint8> {
2050   static constexpr dnn::QuantizedActivationMode kModeId =
2051       dnn::QuantizedActivationMode::k8Bit;
2052 };
2053 
2054 template <>
2055 struct Quantization<uint16> {
2056   static constexpr dnn::QuantizedActivationMode kModeId =
2057       dnn::QuantizedActivationMode::k16Bit;
2058 };
2059 
2060 template <>
2061 struct Quantization<int32> {
2062   static constexpr dnn::QuantizedActivationMode kModeId =
2063       dnn::QuantizedActivationMode::k32Bit;
2064 };
2065 
2066 }  // namespace stream_executor
2067 
2068 #endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
2069