/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The CUDA-specific DNN library support, implementing the general DnnSupport
// interface.

#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_

#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/temporary_device_memory.h"

namespace stream_executor {
namespace gpu {

class GpuExecutor;
class CudnnRnnDescriptor;
class CudnnRnnSequenceTensorDescriptor;
class CudnnRnnStateTensorDescriptor;
class CudnnCtcLossDescriptor;

// Opaque and unique identifier for the cuDNN plugin.
extern const PluginId kCuDnnPlugin;
// cuDNN-library-based DNN support. For details on the overridden interface
// functions, see dnn.h.
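//
// A CudnnSupport instance is normally reached through the generic
// dnn::DnnSupport interface rather than constructed directly. A minimal
// sketch, assuming a StreamExecutor already initialized for the CUDA
// platform (names are illustrative, not prescriptive):
//
//   StreamExecutor* executor = ...;            // CUDA-backed executor
//   dnn::DnnSupport* dnn = executor->AsDnn();  // null if no DNN plugin loaded
//   if (dnn != nullptr && dnn->GetVersion().ok()) {
//     // Safe to issue cuDNN-backed operations through `dnn`.
//   }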
class CudnnSupport : public dnn::DnnSupport {
 public:
  explicit CudnnSupport(GpuExecutor* parent);

  port::Status Init() override;
  port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override;

  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size, int cell_size,
      int batch_size, dnn::RnnInputMode input_mode,
      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
      dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
      float dropout, uint64 seed, ScratchAllocator* state_allocator,
      bool use_padded_io) override;

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    dnn::DataType data_type) override;

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    const absl::Span<const int>& seq_lengths,
                                    bool time_major,
                                    dnn::DataType data_type) override;

  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type) override;
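
  // A minimal sketch of descriptor setup for a single-layer LSTM; the sizes
  // are placeholders and `state_allocator` is a caller-provided
  // ScratchAllocator (hypothetical values, not a prescribed configuration):
  //
  //   auto rnn_desc = cudnn->createRnnDescriptor(
  //       /*num_layers=*/1, /*hidden_size=*/128, /*input_size=*/64,
  //       /*cell_size=*/128, /*batch_size=*/32,
  //       dnn::RnnInputMode::kRnnLinearSkip,
  //       dnn::RnnDirectionMode::kRnnUnidirectional, dnn::RnnMode::kRnnLstm,
  //       dnn::DataType::kFloat, dnn::AlgorithmConfig(), /*dropout=*/0.0f,
  //       /*seed=*/0, state_allocator, /*use_padded_io=*/false);
  //   auto seq_desc = cudnn->createRnnSequenceTensorDescriptor(
  //       /*max_seq_length=*/20, /*batch_size=*/32, /*data_size=*/64,
  //       dnn::DataType::kFloat);
  //   auto state_desc = cudnn->createRnnStateTensorDescriptor(
  //       /*num_layer=*/1, /*batch_size=*/32, /*data_size=*/128,
  //       dnn::DataType::kFloat);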

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<Eigen::half>& input_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<Eigen::half>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<Eigen::half>& input_c_data,
                    const DeviceMemory<Eigen::half>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<Eigen::half>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<Eigen::half>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<Eigen::half>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<float>& input_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<float>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<float>& input_c_data,
                    const DeviceMemory<float>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<float>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<float>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<float>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<double>& input_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<double>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<double>& input_c_data,
                    const DeviceMemory<double>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<double>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<double>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<double>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;
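
  // A hedged sketch of a float forward pass, where rnn_desc, seq_desc, and
  // state_desc are the unwrapped results of the create* calls above and
  // every DeviceMemory<float> buffer is caller-allocated:
  //
  //   bool ok = cudnn->DoRnnForward(
  //       stream, *rnn_desc, *seq_desc, input_data, *state_desc, input_h,
  //       *state_desc, input_c, params, *seq_desc, &output_data,
  //       *state_desc, &output_h, *state_desc, &output_c,
  //       /*is_training=*/true, reserve_space_allocator, workspace_allocator,
  //       /*output_profile_result=*/nullptr);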

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<Eigen::half>& input_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<Eigen::half>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<Eigen::half>& input_c_data,
                     const DeviceMemory<Eigen::half>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<Eigen::half>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<Eigen::half>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<Eigen::half>& output_c_data,
                     const DeviceMemory<Eigen::half>& output_backprop_data,
                     const DeviceMemory<Eigen::half>& output_h_backprop_data,
                     const DeviceMemory<Eigen::half>& output_c_backprop_data,
                     DeviceMemory<Eigen::half>* input_backprop_data,
                     DeviceMemory<Eigen::half>* input_h_backprop_data,
                     DeviceMemory<Eigen::half>* input_c_backprop_data,
                     DeviceMemory<Eigen::half>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<float>& input_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<float>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<float>& input_c_data,
                     const DeviceMemory<float>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<float>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<float>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<float>& output_c_data,
                     const DeviceMemory<float>& output_backprop_data,
                     const DeviceMemory<float>& output_h_backprop_data,
                     const DeviceMemory<float>& output_c_backprop_data,
                     DeviceMemory<float>* input_backprop_data,
                     DeviceMemory<float>* input_h_backprop_data,
                     DeviceMemory<float>* input_c_backprop_data,
                     DeviceMemory<float>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<double>& input_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<double>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<double>& input_c_data,
                     const DeviceMemory<double>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<double>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<double>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<double>& output_c_data,
                     const DeviceMemory<double>& output_backprop_data,
                     const DeviceMemory<double>& output_h_backprop_data,
                     const DeviceMemory<double>& output_c_backprop_data,
                     DeviceMemory<double>* input_backprop_data,
                     DeviceMemory<double>* input_h_backprop_data,
                     DeviceMemory<double>* input_c_backprop_data,
                     DeviceMemory<double>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool GetConvolveAlgorithms(
      bool with_winograd_nonfused, int cc_major, int cc_minor,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetRnnAlgorithms(
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveBackwardDataAlgorithms(
      bool with_winograd_nonfused, int cc_major, int cc_minor,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveBackwardFilterAlgorithms(
      bool with_winograd_nonfused, int cc_major, int cc_minor,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;
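
  // A hedged sketch of enumerating candidate algorithms ahead of autotuning;
  // compute capability 7.0 is an arbitrary example value:
  //
  //   std::vector<dnn::AlgorithmDesc> algorithms;
  //   if (cudnn->GetConvolveAlgorithms(/*with_winograd_nonfused=*/true,
  //                                    /*cc_major=*/7, /*cc_minor=*/0,
  //                                    &algorithms)) {
  //     // Each candidate can then be timed through DoConvolve with a
  //     // dnn::ProfileResult to select the fastest one.
  //   }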

  bool DoBatchNormalizationForward(
      Stream* stream, const DeviceMemory<float>& x,
      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
      const DeviceMemory<float>& estimated_mean,
      const DeviceMemory<float>& estimated_variance,
      const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
      DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
      DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationForward(
      Stream* stream, const DeviceMemory<Eigen::half>& x,
      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
      const DeviceMemory<float>& estimated_mean,
      const DeviceMemory<float>& estimated_variance,
      const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
      DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
      DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationBackward(
      Stream* stream, const DeviceMemory<float>& y_backprop,
      const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
      const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
      DeviceMemory<float>* offset_backprop,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationBackward(
      Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
      const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
      const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<Eigen::half>* x_backprop,
      DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator) override;
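
  // A hedged sketch of the float forward pass in training mode; all buffers
  // are caller-allocated and the epsilon value is a placeholder. In training
  // mode the estimated mean/variance inputs are typically unused and may be
  // left empty:
  //
  //   bool ok = cudnn->DoBatchNormalizationForward(
  //       stream, x, scale, offset, /*estimated_mean=*/{},
  //       /*estimated_variance=*/{}, /*side_input=*/{}, x_desc,
  //       scale_offset_desc, /*epsilon=*/1e-5,
  //       /*exponential_average_factor=*/1.0, dnn::ActivationMode::kNone, &y,
  //       &batch_mean, &batch_var, &saved_mean, &saved_inv_var,
  //       /*is_training=*/true, reserve_space_allocator, workspace_allocator);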

  port::Status DoConvolve(
      dnn::ConvolutionKind kind, dnn::DataType element_type,
      dnn::DataType output_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
      dnn::ProfileResult* output_profile_result) override;
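
  // Conceptually, convolution is a two-step flow: DoPrepareForConvolution
  // (declared further below) resolves the algorithm and scratch memory, and
  // DoConvolve then runs it. A hedged illustration of how the pieces fit,
  // with all descriptors and buffers assumed to be set up by the caller:
  //
  //   dnn::AlgorithmDesc algo;
  //   DeviceMemory<uint8> scratch;
  //   auto prep = cudnn->DoPrepareForConvolution(
  //       dnn::ConvolutionKind::FORWARD, dnn::DataType::kFloat, stream,
  //       input_desc, input_data, filter_desc, filter_data, output_desc,
  //       output_data, conv_desc, dnn::AlgorithmConfig(), scratch_allocator,
  //       &algo, &scratch);
  //   if (prep.ok()) {
  //     auto status = cudnn->DoConvolve(
  //         dnn::ConvolutionKind::FORWARD, dnn::DataType::kFloat,
  //         dnn::DataType::kFloat, stream, input_desc, input_data,
  //         filter_desc, filter_data, output_desc, output_data, conv_desc,
  //         algo, scratch, /*output_profile_result=*/nullptr);
  //   }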

  port::Status DoFusedConvolve(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<double>& conv_input_data, double conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<double>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<double>& side_input_data, double side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

  port::Status DoFusedConvolve(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<float>& conv_input_data, float conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<float>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<float>& side_input_data, float side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

  port::Status DoFusedConvolve(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<Eigen::half>& conv_input_data, float conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<Eigen::half>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<Eigen::half>& side_input_data, float side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<Eigen::half>& biases,
      dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<Eigen::half>* output_data,
      ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

  port::Status DoFusedConvolve(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int8>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<int8>& side_input_data, float side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

  port::Status DoFusedConvolve(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int8>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<float>& side_input_data, float side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;
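
  // A hedged sketch of the float fused path, which folds convolution, side
  // input scaling, bias addition, and activation into a single call; the
  // scale values are placeholders:
  //
  //   auto status = cudnn->DoFusedConvolve(
  //       stream, conv_input_desc, conv_input_data, /*conv_input_scale=*/1.0f,
  //       filter_desc, filter_data, conv_desc, side_input_data,
  //       /*side_input_scale=*/0.0f, bias_desc, biases,
  //       dnn::ActivationMode::kRelu, output_desc, &output_data,
  //       scratch_allocator, dnn::AlgorithmConfig(),
  //       /*output_profile_result=*/nullptr);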

  bool DoConvolveQuantized(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int8>& filter_coefficients,
      const DeviceMemory<float>& coefficient_scales,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DoConvolveQuantized not supported by cuDNN";
    return false;
  }

  bool DoConvolveQuantized(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int16>& filter_coefficients,
      const DeviceMemory<float>& coefficient_scales,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DoConvolveQuantized not supported by cuDNN";
    return false;
  }

  bool DoSeparableConvolve(
      Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor, int depth_multiplier,
      const DeviceMemory<float>& first_weights,
      const DeviceMemory<float>& second_weights,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "separable convolution not supported by CUDNN";
    return false;
  }

  bool DoConvolveBackwardBias(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<double>& input_data,
      const dnn::BatchDescriptor& bias_descriptor,
      DeviceMemory<double>* backward_bias_data) override;

  bool DoConvolveBackwardBias(Stream* stream,
                              const dnn::BatchDescriptor& input_descriptor,
                              const DeviceMemory<float>& input_data,
                              const dnn::BatchDescriptor& bias_descriptor,
                              DeviceMemory<float>* backward_bias_data) override;

  bool DoConvolveBackwardBias(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<Eigen::half>& input_data,
      const dnn::BatchDescriptor& bias_descriptor,
      DeviceMemory<Eigen::half>* backward_bias_data) override;

  bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
                const DeviceMemory<float>& weights,
                const dnn::BatchDescriptor& input_dimensions,
                const dnn::BatchDescriptor& output_dimensions,
                DeviceMemory<float>* output_data) override;

  bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
                         const DeviceMemory<int8>& quantized_weights,
                         const DeviceMemory<float>& weight_scales,
                         const dnn::BatchDescriptor& input_dimensions,
                         const dnn::BatchDescriptor& output_dimensions,
                         DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DNN MatMulQuantized not supported by CUDNN";
    return false;
  }

  bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
                         const DeviceMemory<int16>& quantized_weights,
                         const DeviceMemory<float>& weight_scales,
                         const dnn::BatchDescriptor& input_dimensions,
                         const dnn::BatchDescriptor& output_dimensions,
                         DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DNN MatMulQuantized not supported by CUDNN";
    return false;
  }

  bool DoBiasAdd(Stream* stream, const DeviceMemory<float>& input_data,
                 const DeviceMemory<float>& biases,
                 const dnn::BatchDescriptor& dimensions,
                 DeviceMemory<float>* output_data) override;

  bool DoActivate(Stream* stream, dnn::ActivationMode activation_mode,
                  const dnn::BatchDescriptor& dimensions,
                  const DeviceMemory<float>& input_data,
                  DeviceMemory<float>* output_data, uint64 options) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<double>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<double>* output_data,
                     ScratchAllocator* workspace_allocator) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<float>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<float>* output_data,
                     ScratchAllocator* workspace_allocator) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<Eigen::half>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<Eigen::half>* output_data,
                     ScratchAllocator* workspace_allocator) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<int8>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<int8>* output_data,
                     ScratchAllocator* workspace_allocator) override;

  bool DoPoolBackward(Stream* stream,
                      const dnn::PoolingDescriptor& pooling_dimensions,
                      const dnn::BatchDescriptor& input_dimensions,
                      const DeviceMemory<double>& input_data,
                      const dnn::BatchDescriptor& output_dimensions,
                      const DeviceMemory<double>& output_data,
                      const DeviceMemory<double>& input_diff_data,
                      DeviceMemory<double>* output_diff_data,
                      ScratchAllocator* workspace_allocator) override;

  bool DoPoolBackward(Stream* stream,
                      const dnn::PoolingDescriptor& pooling_dimensions,
                      const dnn::BatchDescriptor& input_dimensions,
                      const DeviceMemory<float>& input_data,
                      const dnn::BatchDescriptor& output_dimensions,
                      const DeviceMemory<float>& output_data,
                      const DeviceMemory<float>& input_diff_data,
                      DeviceMemory<float>* output_diff_data,
                      ScratchAllocator* workspace_allocator) override;

  bool DoPoolBackward(Stream* stream,
                      const dnn::PoolingDescriptor& pooling_dimensions,
                      const dnn::BatchDescriptor& input_dimensions,
                      const DeviceMemory<Eigen::half>& input_data,
                      const dnn::BatchDescriptor& output_dimensions,
                      const DeviceMemory<Eigen::half>& output_data,
                      const DeviceMemory<Eigen::half>& input_diff_data,
                      DeviceMemory<Eigen::half>* output_diff_data,
                      ScratchAllocator* workspace_allocator) override;
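
  // A hedged sketch of a 2x2 max-pooling round trip, assuming the chainable
  // PoolingDescriptor setters from dnn.h:
  //
  //   dnn::PoolingDescriptor pool_desc;
  //   pool_desc.set_pooling_mode(dnn::PoolingMode::kMaximum)
  //       .set_window_height(2)
  //       .set_window_width(2)
  //       .set_vertical_stride(2)
  //       .set_horizontal_stride(2);
  //   cudnn->DoPoolForward(stream, pool_desc, input_dims, input_data,
  //                        output_dims, &output_data, workspace_allocator);
  //   // The backward pass reuses the same descriptors plus the forward
  //   // output and the incoming gradient (input_diff_data).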

  bool DoNormalizeWithDimensions(
      Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
      const dnn::BatchDescriptor& dimensions,
      const DeviceMemory<float>& input_data,
      DeviceMemory<float>* output_data) override;

  bool DoNormalizeBackwardWithDimensions(
      Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
      const dnn::BatchDescriptor& dimensions,
      const DeviceMemory<float>& raw_data,
      const DeviceMemory<float>& normalized_data,
      const DeviceMemory<float>& normalized_variable_gradient,
      DeviceMemory<float>* raw_variable_gradient,
      ScratchAllocator* workspace_allocator) override;

  bool DoDepthConcatenate(
      Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float>*> input_data,
      DeviceMemory<float>* output_data) override;

  bool DoElementwiseOperate(
      Stream* stream, dnn::ElementwiseOperation operation,
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float>*> input_data,
      const dnn::BatchDescriptor& output_dimensions,
      DeviceMemory<float>* output_data) override;

  bool DoXYPad(Stream* stream, const dnn::BatchDescriptor& dimensions,
               const DeviceMemory<float>& input_data, int64 left_pad,
               int64 right_pad, int64 top_pad, int64 bottom_pad,
               DeviceMemory<float>* output_data) override;

  bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor& dimensions,
                 const DeviceMemory<float>& input_data, int64 left_trim,
                 int64 right_trim, int64 top_trim, int64 bottom_trim,
                 DeviceMemory<float>* output_data) override;

  bool DoMemcpyD2HQuantized(Stream* stream,
                            const DeviceMemory<float>& device_unquantized_src,
                            dnn::QuantizedActivationMode mode, void* host_dst,
                            int64 size) override;

  bool DoMemcpyH2DQuantized(
      Stream* stream, const void* host_src, int64 size,
      dnn::QuantizedActivationMode mode,
      DeviceMemory<float>* device_unquantized_dst) override;

  // Derives an output batch descriptor from an input batch and convolution
  // descriptors.
  bool DeriveOutputBatchDescriptor(
      const dnn::BatchDescriptor& batch_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::BatchDescriptor* output_batch_descriptor);
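
  // A minimal shape-inference sketch, assuming the chainable BatchDescriptor
  // setters from dnn.h; the concrete sizes are placeholders:
  //
  //   dnn::BatchDescriptor input_desc;
  //   input_desc.set_count(32).set_feature_map_count(3).set_height(224)
  //       .set_width(224);
  //   dnn::BatchDescriptor output_desc;
  //   if (cudnn->DeriveOutputBatchDescriptor(input_desc, filter_desc,
  //                                          conv_desc, &output_desc)) {
  //     // output_desc now describes the shape the convolution will produce.
  //   }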

  port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
                         const dnn::RnnStateTensorDescriptor& probs_desc,
                         const DeviceMemoryBase probs_data,
                         absl::Span<const int> labels_data,
                         absl::Span<const int> labels_lengths_data,
                         absl::Span<const int> input_lengths_data,
                         DeviceMemoryBase costs_data,
                         const dnn::RnnStateTensorDescriptor& grads_desc,
                         DeviceMemoryBase grads_data,
                         DeviceMemory<uint8> scratch_memory,
                         int ctc_loss_algo_id) override;
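
  // Like convolution, CTC loss is a two-step flow: DoPrepareForCtcLoss
  // (declared below) selects an algorithm id and scratch buffer, which are
  // then handed to DoCtcLoss. A hedged sketch with placeholder label and
  // length spans:
  //
  //   DeviceMemory<uint8> scratch;
  //   int algo_id = 0;
  //   auto prep = cudnn->DoPrepareForCtcLoss(
  //       stream, dnn::DataType::kFloat, probs_desc, grads_desc, labels,
  //       label_lengths, input_lengths, scratch_allocator, &scratch,
  //       &algo_id);
  //   if (prep.ok()) {
  //     auto status = cudnn->DoCtcLoss(
  //         stream, dnn::DataType::kFloat, probs_desc, probs_data, labels,
  //         label_lengths, input_lengths, costs_data, grads_desc, grads_data,
  //         scratch, algo_id);
  //   }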

  bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
                         dnn::DataType input_type,
                         const DeviceMemoryBase& input_data,
                         const dnn::BatchDescriptor& output_desc,
                         dnn::DataType output_type, float scale,
                         DeviceMemoryBase* output_data) override;

 private:
  GpuExecutor* parent_;  // Parent executor object. Not owned.

  // Provides access to the cuDNN handle.
  std::unique_ptr<class CudnnAccess> cudnn_;

  template <class T, class U>
  port::Status DoBatchNormalizationForwardImpl(
      Stream* stream, dnn::DataType input_data_type,
      dnn::DataType scale_data_type, const DeviceMemory<T>& x,
      const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
      const DeviceMemory<U>& estimated_mean,
      const DeviceMemory<U>& estimated_variance,
      const DeviceMemory<U>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
      DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
      DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator);

  template <class T, class U>
  port::Status DoBatchNormalizationBackwardImpl(
      Stream* stream, int cudnn_input_type, int cudnn_scale_type,
      const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
      const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
      const DeviceMemory<U>& inv_var, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
      DeviceMemory<U>* offset_backprop, DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator);

  template <typename ElementType, typename BiasType, typename ScaleType,
            typename OutputType>
  port::Status DoFusedConvolveImpl(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<ElementType>& conv_input_data,
      ScaleType conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<ElementType>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<OutputType>& side_input_data,
      ScaleType side_input_scale, const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<OutputType>* output_data, dnn::DataType accumulator_type,
      ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result);

  template <class T>
  port::Status DoConvolveBackwardBiasImpl(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<T>& input_data,
      const dnn::BatchDescriptor& bias_descriptor,
      DeviceMemory<T>* backward_bias_data);

  template <class T>
  port::Status DoRnnForwardImpl(
      Stream* stream, const CudnnRnnDescriptor& rnn_desc,
      const CudnnRnnSequenceTensorDescriptor& input_desc,
      const DeviceMemory<T>& input_data,
      const CudnnRnnStateTensorDescriptor& input_h_desc,
      const DeviceMemory<T>& input_h_data,
      const CudnnRnnStateTensorDescriptor& input_c_desc,
      const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
      const CudnnRnnSequenceTensorDescriptor& output_desc,
      DeviceMemory<T>* output_data,
      const CudnnRnnStateTensorDescriptor& output_h_desc,
      DeviceMemory<T>* output_h_data,
      const CudnnRnnStateTensorDescriptor& output_c_desc,
      DeviceMemory<T>* output_c_data, bool is_training,
      ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator,
      dnn::ProfileResult* output_profile_result);

  template <class T>
  port::Status DoRnnBackwardImpl(
      Stream* stream, const CudnnRnnDescriptor& rnn_desc,
      const CudnnRnnSequenceTensorDescriptor& input_desc,
      const DeviceMemory<T>& input_data,
      const CudnnRnnStateTensorDescriptor& input_h_desc,
      const DeviceMemory<T>& input_h_data,
      const CudnnRnnStateTensorDescriptor& input_c_desc,
      const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
      const CudnnRnnSequenceTensorDescriptor& output_desc,
      const DeviceMemory<T>& output_data,
      const CudnnRnnStateTensorDescriptor& output_h_desc,
      const DeviceMemory<T>& output_h_data,
      const CudnnRnnStateTensorDescriptor& output_c_desc,
      const DeviceMemory<T>& output_c_data,
      const DeviceMemory<T>& output_backprop_data,
      const DeviceMemory<T>& output_h_backprop_data,
      const DeviceMemory<T>& output_c_backprop_data,
      DeviceMemory<T>* input_backprop_data,
      DeviceMemory<T>* input_h_backprop_data,
      DeviceMemory<T>* input_c_backprop_data,
      DeviceMemory<T>* params_backprop_data,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator,
      dnn::ProfileResult* output_profile_result);

  port::Status DoCtcLossImpl(
      Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
      const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
      const CudnnRnnStateTensorDescriptor& grads_desc,
      DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
      DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);

 private:
  port::Status DoPrepareForConvolution(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::AlgorithmConfig& algorithm_config,
      ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
      DeviceMemory<uint8>* scratch_memory) override;

  port::Status DoPrepareForCtcLoss(
      Stream* stream, dnn::DataType element_type,
      const dnn::RnnStateTensorDescriptor& probs_desc,
      const dnn::RnnStateTensorDescriptor& grads_desc,
      absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data,
      ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
      int* ctc_loss_algo_id) override;

  SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
};

}  // namespace gpu
}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_