/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The ROCM-specific DNN library support, implementing the general DnnSupport
// interface.

#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_

#include "absl/synchronization/mutex.h"
#include "rocm/include/miopen/miopen.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/temporary_device_memory.h"

namespace stream_executor {
namespace gpu {

class GpuExecutor;
class MIOpenRnnDescriptor;
class MIOpenRnnSequenceTensorDescriptor;
class MIOpenRnnStateTensorDescriptor;
class MIOpenCTCLossDescriptor;

// Opaque and unique identifier for the MIOpen plugin.
extern const PluginId kMIOpenPlugin;

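// Describes one cached pooling workspace: the pooling op and tensor
// dimensions it was created for, its element type, an LRU timestamp, and the
// workspace buffer itself.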
struct PoolingWorkspaceDescriptor {
  std::vector<int64> input_dims;
  std::vector<int64> output_dims;
  dnn::PoolingDescriptor op;
  int dtype;
  uint64_t timestamp;
  std::unique_ptr<TemporaryDeviceMemory<uint8>> workspace;
  size_t workspace_size;
  bool IsSame(const dnn::BatchDescriptor& input_dimensions,
              const dnn::BatchDescriptor& output_dimensions,
              const dnn::PoolingDescriptor& pooling_dimensions, int _type);
};

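// A cache of pooling workspaces, keyed by buffer pointer. Entries carry an
// LRU timestamp so that trim() can evict the oldest ones, keeping the cache
// within trim_size entries and memory_budget bytes.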
struct PoolingWorkspaceCache {
  std::map<const void*, PoolingWorkspaceDescriptor> cache;
  const int trim_size = 1000;
  const uint64_t memory_budget = 2e7;
  uint64_t timestamp = 0;
  uint64_t memory_used = 0;
  bool find(const void* p, const dnn::BatchDescriptor& input_dimensions,
            const dnn::BatchDescriptor& output_dimensions,
            const dnn::PoolingDescriptor& pooling_dimensions, int _type,
            PoolingWorkspaceDescriptor*& pdesc);
  void insert(const void* p, const dnn::BatchDescriptor& input_dimensions,
              const dnn::BatchDescriptor& output_dimensions,
              const dnn::PoolingDescriptor& pooling_dimensions, int _type,
              std::unique_ptr<TemporaryDeviceMemory<uint8>>& workspace,
              size_t wsp_size, hipStream_t hip_stream);

 private:
  void trim(hipStream_t hip_stream);
};
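
// Illustrative sketch of the intended call pattern (variable names here are
// hypothetical; the real call sites live in rocm_dnn.cc):
//
//   PoolingWorkspaceDescriptor* pdesc = nullptr;
//   if (cache.find(input_ptr, input_desc, output_desc, pooling_desc,
//                  miopen_type, pdesc)) {
//     // Reuse pdesc->workspace.
//   } else {
//     // Allocate a fresh workspace, run the op, then cache it:
//     cache.insert(input_ptr, input_desc, output_desc, pooling_desc,
//                  miopen_type, workspace, workspace_size, hip_stream);
//   }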

// MIOpen-library based DNN support. For details on overridden interface
// functions, see dnn.h.
class MIOpenSupport : public dnn::DnnSupport {
 public:
  explicit MIOpenSupport(GpuExecutor* parent);

  port::Status Init() override;
  port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override;

  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size, int cell_size,
      int batch_size, dnn::RnnInputMode input_mode,
      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
      dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
      float dropout, uint64 seed, ScratchAllocator* state_allocator,
      bool use_padded_io) override;

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
                                    int data_size,
                                    dnn::DataType data_type) override;

  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<Eigen::half>& input_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<Eigen::half>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<Eigen::half>& input_c_data,
                    const DeviceMemory<Eigen::half>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<Eigen::half>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<Eigen::half>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<Eigen::half>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<float>& input_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<float>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<float>& input_c_data,
                    const DeviceMemory<float>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<float>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<float>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<float>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<double>& input_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<double>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<double>& input_c_data,
                    const DeviceMemory<double>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<double>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<double>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<double>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<Eigen::half>& input_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<Eigen::half>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<Eigen::half>& input_c_data,
                     const DeviceMemory<Eigen::half>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<Eigen::half>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<Eigen::half>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<Eigen::half>& output_c_data,
                     const DeviceMemory<Eigen::half>& output_backprop_data,
                     const DeviceMemory<Eigen::half>& output_h_backprop_data,
                     const DeviceMemory<Eigen::half>& output_c_backprop_data,
                     DeviceMemory<Eigen::half>* input_backprop_data,
                     DeviceMemory<Eigen::half>* input_h_backprop_data,
                     DeviceMemory<Eigen::half>* input_c_backprop_data,
                     DeviceMemory<Eigen::half>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<float>& input_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<float>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<float>& input_c_data,
                     const DeviceMemory<float>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<float>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<float>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<float>& output_c_data,
                     const DeviceMemory<float>& output_backprop_data,
                     const DeviceMemory<float>& output_h_backprop_data,
                     const DeviceMemory<float>& output_c_backprop_data,
                     DeviceMemory<float>* input_backprop_data,
                     DeviceMemory<float>* input_h_backprop_data,
                     DeviceMemory<float>* input_c_backprop_data,
                     DeviceMemory<float>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<double>& input_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<double>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<double>& input_c_data,
                     const DeviceMemory<double>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<double>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<double>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<double>& output_c_data,
                     const DeviceMemory<double>& output_backprop_data,
                     const DeviceMemory<double>& output_h_backprop_data,
                     const DeviceMemory<double>& output_c_backprop_data,
                     DeviceMemory<double>* input_backprop_data,
                     DeviceMemory<double>* input_h_backprop_data,
                     DeviceMemory<double>* input_c_backprop_data,
                     DeviceMemory<double>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool GetConvolveAlgorithms(
      bool with_winograd_nonfused, int cc_major, int cc_minor,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetMIOpenConvolveAlgorithms(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      ScratchAllocator* scratch_allocator,
      std::vector<dnn::ProfileResult>* out_algorithms) override;

  bool GetRnnAlgorithms(
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveBackwardDataAlgorithms(
      bool with_winograd_nonfused, int cc_major, int cc_minor,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveBackwardFilterAlgorithms(
      bool with_winograd_nonfused, int cc_major, int cc_minor,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool DoBatchNormalizationForward(
      Stream* stream, const DeviceMemory<float>& x,
      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
      const DeviceMemory<float>& estimated_mean,
      const DeviceMemory<float>& estimated_variance,
      const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
      DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
      DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationForward(
      Stream* stream, const DeviceMemory<Eigen::half>& x,
      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
      const DeviceMemory<float>& estimated_mean,
      const DeviceMemory<float>& estimated_variance,
      const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
      DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
      DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationBackward(
      Stream* stream, const DeviceMemory<float>& y_backprop,
      const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
      const DeviceMemory<float>& mean, const DeviceMemory<float>& variance,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
      DeviceMemory<float>* offset_backprop,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationBackward(
      Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
      const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
      const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<Eigen::half>* x_backprop,
      DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator) override;

  port::Status DoConvolve(
      dnn::ConvolutionKind kind, dnn::DataType element_type,
      dnn::DataType output_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
      dnn::ProfileResult* output_profile_result) override;

  port::Status DoFusedConvolve(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<double>& conv_input_data, double conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<double>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<double>& side_input_data, double side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

  port::Status DoFusedConvolve(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<float>& conv_input_data, float conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<float>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<float>& side_input_data, float side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

  port::Status DoFusedConvolve(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<Eigen::half>& conv_input_data, float conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<Eigen::half>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<Eigen::half>& side_input_data, float side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<Eigen::half>& biases,
      dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<Eigen::half>* output_data,
      ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

  port::Status DoFusedConvolve(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int8>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<int8>& side_input_data, float side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

  bool DoConvolveQuantized(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int8>& filter_coefficients,
      const DeviceMemory<float>& coefficient_scales,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DoConvolveQuantized not supported by MIOpen";
    return false;
  }

  bool DoConvolveQuantized(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int16>& filter_coefficients,
      const DeviceMemory<float>& coefficient_scales,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DoConvolveQuantized not supported by MIOpen";
    return false;
  }

  bool DoSeparableConvolve(
      Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor, int depth_multiplier,
      const DeviceMemory<float>& first_weights,
      const DeviceMemory<float>& second_weights,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "separable convolution not supported by MIOpen";
    return false;
  }

  bool DoConvolveBackwardBias(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<double>& input_data,
      const dnn::BatchDescriptor& bias_descriptor,
      DeviceMemory<double>* backward_bias_data) override;

  bool DoConvolveBackwardBias(Stream* stream,
                              const dnn::BatchDescriptor& input_descriptor,
                              const DeviceMemory<float>& input_data,
                              const dnn::BatchDescriptor& bias_descriptor,
                              DeviceMemory<float>* backward_bias_data) override;

  bool DoConvolveBackwardBias(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<Eigen::half>& input_data,
      const dnn::BatchDescriptor& bias_descriptor,
      DeviceMemory<Eigen::half>* backward_bias_data) override;

  bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
                const DeviceMemory<float>& weights,
                const dnn::BatchDescriptor& input_dimensions,
                const dnn::BatchDescriptor& output_dimensions,
                DeviceMemory<float>* output_data) override;

  bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
                         const DeviceMemory<int8>& quantized_weights,
                         const DeviceMemory<float>& weight_scales,
                         const dnn::BatchDescriptor& input_dimensions,
                         const dnn::BatchDescriptor& output_dimensions,
                         DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DNN MatMulQuantized not supported by MIOpen";
    return false;
  }

  bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
                         const DeviceMemory<int16>& quantized_weights,
                         const DeviceMemory<float>& weight_scales,
                         const dnn::BatchDescriptor& input_dimensions,
                         const dnn::BatchDescriptor& output_dimensions,
                         DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DNN MatMulQuantized not supported by MIOpen";
    return false;
  }

  bool DoBiasAdd(Stream* stream, const DeviceMemory<float>& input_data,
                 const DeviceMemory<float>& biases,
                 const dnn::BatchDescriptor& dimensions,
                 DeviceMemory<float>* output_data) override;

  bool DoActivate(Stream* stream, dnn::ActivationMode activation_mode,
                  const dnn::BatchDescriptor& dimensions,
                  const DeviceMemory<float>& input_data,
                  DeviceMemory<float>* output_data, uint64 options) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<double>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<double>* output_data,
                     ScratchAllocator* workspace_allocator = nullptr) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<float>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<float>* output_data,
                     ScratchAllocator* workspace_allocator = nullptr) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<Eigen::half>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<Eigen::half>* output_data,
                     ScratchAllocator* workspace_allocator = nullptr) override;

  bool DoPoolBackward(Stream* stream,
                      const dnn::PoolingDescriptor& pooling_dimensions,
                      const dnn::BatchDescriptor& input_dimensions,
                      const DeviceMemory<double>& input_data,
                      const dnn::BatchDescriptor& output_dimensions,
                      const DeviceMemory<double>& output_data,
                      const DeviceMemory<double>& input_diff_data,
                      DeviceMemory<double>* output_diff_data,
                      ScratchAllocator* workspace_allocator = nullptr) override;

  bool DoPoolBackward(Stream* stream,
                      const dnn::PoolingDescriptor& pooling_dimensions,
                      const dnn::BatchDescriptor& input_dimensions,
                      const DeviceMemory<float>& input_data,
                      const dnn::BatchDescriptor& output_dimensions,
                      const DeviceMemory<float>& output_data,
                      const DeviceMemory<float>& input_diff_data,
                      DeviceMemory<float>* output_diff_data,
                      ScratchAllocator* workspace_allocator = nullptr) override;

  bool DoPoolBackward(Stream* stream,
                      const dnn::PoolingDescriptor& pooling_dimensions,
                      const dnn::BatchDescriptor& input_dimensions,
                      const DeviceMemory<Eigen::half>& input_data,
                      const dnn::BatchDescriptor& output_dimensions,
                      const DeviceMemory<Eigen::half>& output_data,
                      const DeviceMemory<Eigen::half>& input_diff_data,
                      DeviceMemory<Eigen::half>* output_diff_data,
                      ScratchAllocator* workspace_allocator = nullptr) override;

  bool DoNormalizeWithDimensions(
      Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
      const dnn::BatchDescriptor& dimensions,
      const DeviceMemory<float>& input_data,
      DeviceMemory<float>* output_data) override;

  bool DoNormalizeBackwardWithDimensions(
      Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
      const dnn::BatchDescriptor& dimensions,
      const DeviceMemory<float>& raw_data,
      const DeviceMemory<float>& normalized_data,
      const DeviceMemory<float>& normalized_variable_gradient,
      DeviceMemory<float>* raw_variable_gradient,
      ScratchAllocator* workspace_allocator = nullptr) override;

  bool DoDepthConcatenate(
      Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float>*> input_data,
      DeviceMemory<float>* output_data) override;

  bool DoElementwiseOperate(
      Stream* stream, dnn::ElementwiseOperation operation,
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float>*> input_data,
      const dnn::BatchDescriptor& output_dimensions,
      DeviceMemory<float>* output_data) override;

  bool DoXYPad(Stream* stream, const dnn::BatchDescriptor& dimensions,
               const DeviceMemory<float>& input_data, int64 left_pad,
               int64 right_pad, int64 top_pad, int64 bottom_pad,
               DeviceMemory<float>* output_data) override;

  bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor& dimensions,
                 const DeviceMemory<float>& input_data, int64 left_trim,
                 int64 right_trim, int64 top_trim, int64 bottom_trim,
                 DeviceMemory<float>* output_data) override;

  bool DoMemcpyD2HQuantized(Stream* stream,
                            const DeviceMemory<float>& device_unquantized_src,
                            dnn::QuantizedActivationMode mode, void* host_dst,
                            int64 size) override;

  bool DoMemcpyH2DQuantized(
      Stream* stream, const void* host_src, int64 size,
      dnn::QuantizedActivationMode mode,
      DeviceMemory<float>* device_unquantized_dst) override;

  // Derives an output batch descriptor from an input batch and convolution
  // descriptors.
  bool DeriveOutputBatchDescriptor(
      const dnn::BatchDescriptor& batch_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::BatchDescriptor* output_batch_descriptor);

  bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
                         dnn::DataType input_type,
                         const DeviceMemoryBase& input_data,
                         const dnn::BatchDescriptor& output_desc,
                         dnn::DataType output_type, float scale,
                         DeviceMemoryBase* output_data) override;

  bool DoFusedConvolutionBiasActivation(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<float>& conv_input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<float>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<float>& bias_data, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedBatchNormActivationInference(
      Stream* stream, const dnn::BatchDescriptor& x_descriptor,
      const DeviceMemory<float>& x_data,
      const dnn::BatchDescriptor& scale_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data,
      const DeviceMemory<float>& mean_data,
      const DeviceMemory<float>& variance_data, double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedBatchNormActivationInference(
      Stream* stream, const dnn::BatchDescriptor& x_descriptor,
      const DeviceMemory<Eigen::half>& x_data,
      const dnn::BatchDescriptor& scale_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data,
      const DeviceMemory<float>& mean_data,
      const DeviceMemory<float>& variance_data, double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedBatchNormActivationForward(
      Stream* stream, const dnn::BatchDescriptor& x_descriptor,
      const DeviceMemory<float>& x_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data, double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
      DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
      DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedBatchNormActivationForward(
      Stream* stream, const dnn::BatchDescriptor& x_descriptor,
      const DeviceMemory<Eigen::half>& x_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data, double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
      DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
      DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedBatchNormActivationBackward(
      Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
      const DeviceMemory<float>& y_act_backprop_data,
      const DeviceMemory<float>& y_act_data,
      dnn::ActivationMode activation_mode, const DeviceMemory<float>& x_bn_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data,
      const DeviceMemory<float>& saved_mean_data,
      const DeviceMemory<float>& saved_var_data,
      DeviceMemory<float>* x_bn_backprop_data,
      DeviceMemory<float>* scale_backprop_data,
      DeviceMemory<float>* offset_backprop_data,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedBatchNormActivationBackward(
      Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
      const DeviceMemory<Eigen::half>& y_act_backprop_data,
      const DeviceMemory<Eigen::half>& y_act_data,
      dnn::ActivationMode activation_mode,
      const DeviceMemory<Eigen::half>& x_bn_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data,
      const DeviceMemory<float>& saved_mean_data,
      const DeviceMemory<float>& saved_var_data,
      DeviceMemory<Eigen::half>* x_bn_backprop_data,
      DeviceMemory<float>* scale_backprop_data,
      DeviceMemory<float>* offset_backprop_data,
      dnn::ProfileResult* output_profile_result) override;

  GpuExecutor* GetParentExecutor() { return parent_; }

  port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
                         const dnn::RnnStateTensorDescriptor& probs_desc,
                         const DeviceMemoryBase probs_data,
                         absl::Span<const int> labels_data,
                         absl::Span<const int> labels_lengths_data,
                         absl::Span<const int> input_lengths_data,
                         DeviceMemoryBase costs_data,
                         const dnn::RnnStateTensorDescriptor& grads_desc,
                         DeviceMemoryBase grads_data,
                         DeviceMemory<uint8> scratch_memory,
                         int ctc_loss_algo_id) override;

 private:
  GpuExecutor* parent_;  // Parent executor object. Not owned.

  // Flag to indicate whether Get*Algorithm routines should only return
  // the best algorithm (as opposed to a list of all applicable ones).
  bool return_best_algo_only_;

  // Flag to indicate whether to use Immediate (or Find) mode for convolutions.
  bool use_immediate_mode_;

  // Provides access to the MIOpen handle.
  std::unique_ptr<class MIOpenAccess> miopen_;

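  // Workspace cache for pooling ops, and the flags that gate its use.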
  PoolingWorkspaceCache m_pooling_cache;
  bool m_pooling_cache_allowed = false;
  bool m_pooling_cache_enabled = false;

  template <class T, class U>
  bool DoBatchNormalizationForwardImpl(
      Stream* stream, dnn::DataType input_data_type,
      dnn::DataType scale_data_type, const DeviceMemory<T>& x,
      const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
      const DeviceMemory<U>& estimated_mean,
      const DeviceMemory<U>& estimated_variance,
      const DeviceMemory<U>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
      DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
      DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
      bool is_training);

  template <class T, class U>
  bool DoBatchNormalizationBackwardImpl(
      Stream* stream, int miopen_input_type, int miopen_scale_type,
      const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
      const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
      const DeviceMemory<U>& variance, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
      DeviceMemory<U>* offset_backprop);

  template <class T>
  bool DoConvolveBackwardBiasImpl(
      Stream* stream,
      int miopen_type,  // Actually miopenDataType_t.
      const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<T>& input_data,
      const dnn::BatchDescriptor& bias_descriptor,
      DeviceMemory<T>* backward_bias_data);

  template <class T>
  bool DoRnnForwardImpl(Stream* stream, const MIOpenRnnDescriptor& rnn_desc,
                        const MIOpenRnnSequenceTensorDescriptor& input_desc,
                        const DeviceMemory<T>& input_data,
                        const MIOpenRnnStateTensorDescriptor& input_h_desc,
                        const DeviceMemory<T>& input_h_data,
                        const MIOpenRnnStateTensorDescriptor& input_c_desc,
                        const DeviceMemory<T>& input_c_data,
                        const DeviceMemory<T>& params,
                        const MIOpenRnnSequenceTensorDescriptor& output_desc,
                        DeviceMemory<T>* output_data,
                        const MIOpenRnnStateTensorDescriptor& output_h_desc,
                        DeviceMemory<T>* output_h_data,
                        const MIOpenRnnStateTensorDescriptor& output_c_desc,
                        DeviceMemory<T>* output_c_data, bool is_training,
                        ScratchAllocator* reserve_space_allocator,
                        ScratchAllocator* workspace_allocator);
  template <class T>
  bool DoRnnBackwardImpl(Stream* stream, const MIOpenRnnDescriptor& rnn_desc,
                         const MIOpenRnnSequenceTensorDescriptor& input_desc,
                         const DeviceMemory<T>& input_data,
                         const MIOpenRnnStateTensorDescriptor& input_h_desc,
                         const DeviceMemory<T>& input_h_data,
                         const MIOpenRnnStateTensorDescriptor& input_c_desc,
                         const DeviceMemory<T>& input_c_data,
                         const DeviceMemory<T>& params,
                         const MIOpenRnnSequenceTensorDescriptor& output_desc,
                         const DeviceMemory<T>& output_data,
                         const MIOpenRnnStateTensorDescriptor& output_h_desc,
                         const DeviceMemory<T>& output_h_data,
                         const MIOpenRnnStateTensorDescriptor& output_c_desc,
                         const DeviceMemory<T>& output_c_data,
                         const DeviceMemory<T>& output_backprop_data,
                         const DeviceMemory<T>& output_h_backprop_data,
                         const DeviceMemory<T>& output_c_backprop_data,
                         DeviceMemory<T>* input_backprop_data,
                         DeviceMemory<T>* input_h_backprop_data,
                         DeviceMemory<T>* input_c_backprop_data,
                         DeviceMemory<T>* params_backprop_data,
                         DeviceMemory<uint8>* reserve_space_data,
                         ScratchAllocator* workspace_allocator);

  template <typename T>
  bool DoFusedConvolutionBiasActivationImpl(
      Stream* stream,
      int miopen_type,  // Actually miopenDataType_t.
      const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<T>& conv_input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<T>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<T>& bias_data, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<T>* output_data, dnn::ProfileResult* output_profile_result);

  template <typename T, typename U>
  bool DoFusedBatchNormActivationInferenceImpl(
      Stream* stream,
      int miopen_type,  // Actually miopenDataType_t.
      const dnn::BatchDescriptor& x_descriptor, const DeviceMemory<T>& x_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
      const DeviceMemory<U>& mean_data, const DeviceMemory<U>& variance_data,
      double epsilon, dnn::ActivationMode activation_mode,
      DeviceMemory<T>* y_data, dnn::ProfileResult* output_profile_result);

  template <typename T, typename U>
  bool DoFusedBatchNormActivationForwardImpl(
      Stream* stream,
      int miopen_type,  // Actually miopenDataType_t.
      const dnn::BatchDescriptor& x_descriptor, const DeviceMemory<T>& x_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
      double epsilon, dnn::ActivationMode activation_mode,
      DeviceMemory<T>* y_data, DeviceMemory<U>* batch_mean_data,
      DeviceMemory<U>* batch_var_data, DeviceMemory<U>* saved_mean_data,
      DeviceMemory<U>* saved_var_data,
      dnn::ProfileResult* output_profile_result);

  template <typename T, typename U>
  bool DoFusedBatchNormActivationBackwardImpl(
      Stream* stream,
      int miopen_type,  // Actually miopenDataType_t.
      const dnn::BatchDescriptor& y_act_backprop_descriptor,
      const DeviceMemory<T>& y_act_backprop_data,
      const DeviceMemory<T>& y_act_data, dnn::ActivationMode activation_mode,
      const DeviceMemory<T>& x_bn_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
      const DeviceMemory<U>& saved_mean_data,
      const DeviceMemory<U>& saved_var_data,
      DeviceMemory<T>* x_bn_backprop_data, DeviceMemory<U>* scale_backprop_data,
      DeviceMemory<U>* offset_backprop_data,
      dnn::ProfileResult* output_profile_result);

  port::Status DoPrepareForConvolution(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::AlgorithmConfig& algorithm_config,
      ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
      DeviceMemory<uint8>* scratch_memory) override;

  port::Status DoCtcLossImpl(
      Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc,
      const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
      const MIOpenRnnStateTensorDescriptor& grads_desc,
      DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
      DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);

  port::Status DoPrepareForCtcLoss(
      Stream* stream, dnn::DataType element_type,
      const dnn::RnnStateTensorDescriptor& probs_desc,
      const dnn::RnnStateTensorDescriptor& grads_desc,
      absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data,
      ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
      int* ctc_loss_algo_id) override;

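  // Implementations behind GetMIOpenConvolveAlgorithms; use_immediate_mode_
  // selects between MIOpen's Immediate mode and its Find mode.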
  bool GetMIOpenConvolveAlgorithmsImmediateMode(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      ScratchAllocator* scratch_allocator,
      std::vector<dnn::ProfileResult>* out_algorithms);

  bool GetMIOpenConvolveAlgorithmsFindMode(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      ScratchAllocator* scratch_allocator,
      std::vector<dnn::ProfileResult>* out_algorithms);

  template <class T>
  bool DoPoolBackwardImpl(Stream* stream,
                          const dnn::PoolingDescriptor& pooling_dimensions,
                          const dnn::BatchDescriptor& input_dimensions,
                          const DeviceMemory<T>& input_data,
                          const dnn::BatchDescriptor& output_dimensions,
                          const DeviceMemory<T>& output_data,
                          const DeviceMemory<T>& input_diff_data,
                          DeviceMemory<T>* output_diff_data,
                          ScratchAllocator* workspace_allocator = nullptr);

  SE_DISALLOW_COPY_AND_ASSIGN(MIOpenSupport);
};

}  // namespace gpu
}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_