1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_THUNK_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_THUNK_H_
18 
19 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
20 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
21 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
22 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/types.h"
25 #include "tensorflow/core/lib/core/status.h"
26 
27 namespace xla {
28 namespace gpu {
29 
30 // This file contains thunks which call into cudnn to run the various flavors of
31 // batch normalization: BatchNormInference, BatchNormTraining, and
32 // BatchNormGrad, known to cudnn as BatchNormForwardInference,
33 // BatchNormForwardTraining, and BatchNormBackward.
34 //
35 // As an alternative to using these thunks, XLA can decompose batchnorm HLOs
36 // into smaller components using the BatchNormRewriter pass.  This can result in
37 // faster code because those individual components can fuse into their
38 // inputs/outputs, but it may also be slower if cudnn's batchnorm implementation
39 // outperforms the code XLA generates for these components.
40 //
41 // Currently these thunks require that their inputs are F32s.
42 //
43 // Note that these thunks do not take full advantage of the cudnn batchnorm
44 // functions.  For example, cudnn lets you bias and/or scale the input/output,
45 // but these thunks don't currently support that.
46 
// Thunk that runs cudnn's BatchNormForwardInference (the inference flavor of
// batch normalization; see the file comment above).
class CudnnBatchNormForwardInferenceThunk : public Thunk {
 public:
  // All slices refer to buffers owned by the BufferAllocations passed to
  // ExecuteOnStream.  `epsilon` and `feature_index` mirror the corresponding
  // batchnorm HLO attributes; `hlo` is the instruction this thunk implements
  // (inputs are required to be F32 — see file comment).
  CudnnBatchNormForwardInferenceThunk(const BufferAllocation::Slice& operand,
                                      const BufferAllocation::Slice& scale,
                                      const BufferAllocation::Slice& offset,
                                      const BufferAllocation::Slice& mean,
                                      const BufferAllocation::Slice& variance,
                                      float epsilon, int64 feature_index,
                                      const BufferAllocation::Slice& output,
                                      const HloInstruction* hlo);

  // Not copyable: this thunk is uniquely associated with one HLO instruction.
  CudnnBatchNormForwardInferenceThunk(
      const CudnnBatchNormForwardInferenceThunk&) = delete;
  CudnnBatchNormForwardInferenceThunk& operator=(
      const CudnnBatchNormForwardInferenceThunk&) = delete;

  // Resolves the slices against `buffer_allocations` and enqueues the cudnn
  // batchnorm inference call on `stream`.
  Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
                         se::Stream* stream,
                         HloExecutionProfiler* profiler) override;

 private:
  // Inputs.
  BufferAllocation::Slice operand_;
  BufferAllocation::Slice scale_;
  BufferAllocation::Slice offset_;
  BufferAllocation::Slice mean_;
  BufferAllocation::Slice variance_;
  float epsilon_;
  int64 feature_index_;
  // Output.
  BufferAllocation::Slice output_;
};
77 
// Thunk that runs cudnn's BatchNormForwardTraining (the training flavor of
// batch normalization; see the file comment above).
class CudnnBatchNormForwardTrainingThunk : public Thunk {
 public:
  // All slices refer to buffers owned by the BufferAllocations passed to
  // ExecuteOnStream.  `epsilon` and `feature_index` mirror the corresponding
  // batchnorm HLO attributes; `hlo` is the instruction this thunk implements.
  // NOTE(review): `output_tuple` presumably holds the tuple aggregating the
  // three output buffers (data, mean, inv-stddev) — confirm against the .cc.
  CudnnBatchNormForwardTrainingThunk(
      const BufferAllocation::Slice& operand,
      const BufferAllocation::Slice& scale,
      const BufferAllocation::Slice& offset, float epsilon, int64 feature_index,
      const BufferAllocation::Slice& output_data,
      const BufferAllocation::Slice& output_mean,
      const BufferAllocation::Slice& output_inv_stddev,
      const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo);

  // Not copyable: this thunk is uniquely associated with one HLO instruction.
  CudnnBatchNormForwardTrainingThunk(
      const CudnnBatchNormForwardTrainingThunk&) = delete;
  CudnnBatchNormForwardTrainingThunk& operator=(
      const CudnnBatchNormForwardTrainingThunk&) = delete;

  // Resolves the slices against `buffer_allocations` and enqueues the cudnn
  // batchnorm training-forward call on `stream`.
  Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
                         se::Stream* stream,
                         HloExecutionProfiler* profiler) override;

 private:
  // Inputs.
  BufferAllocation::Slice operand_;
  BufferAllocation::Slice scale_;
  BufferAllocation::Slice offset_;
  float epsilon_;
  int64 feature_index_;
  // Outputs.
  BufferAllocation::Slice output_data_;
  BufferAllocation::Slice output_mean_;
  BufferAllocation::Slice output_inv_stddev_;
  BufferAllocation::Slice output_tuple_;
};
109 
// Thunk that runs cudnn's BatchNormBackward (the gradient flavor of batch
// normalization; see the file comment above).
class CudnnBatchNormBackwardThunk : public Thunk {
 public:
  // All slices refer to buffers owned by the BufferAllocations passed to
  // ExecuteOnStream.  `epsilon` and `feature_index` mirror the corresponding
  // batchnorm-grad HLO attributes; `hlo` is the instruction this thunk
  // implements.
  // NOTE(review): `output_tuple` presumably holds the tuple aggregating the
  // three gradient output buffers — confirm against the .cc.
  CudnnBatchNormBackwardThunk(const BufferAllocation::Slice& operand,
                              const BufferAllocation::Slice& scale,
                              const BufferAllocation::Slice& mean,
                              const BufferAllocation::Slice& inv_stddev,
                              const BufferAllocation::Slice& grad_output,
                              float epsilon, int64 feature_index,
                              const BufferAllocation::Slice& output_grad_data,
                              const BufferAllocation::Slice& output_grad_scale,
                              const BufferAllocation::Slice& output_grad_offset,
                              const BufferAllocation::Slice& output_tuple,
                              const HloInstruction* hlo);

  // Not copyable: this thunk is uniquely associated with one HLO instruction.
  CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete;
  CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) =
      delete;

  // Resolves the slices against `buffer_allocations` and enqueues the cudnn
  // batchnorm backward call on `stream`.
  Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
                         se::Stream* stream,
                         HloExecutionProfiler* profiler) override;

 private:
  // Inputs.
  BufferAllocation::Slice operand_;
  BufferAllocation::Slice scale_;
  BufferAllocation::Slice mean_;
  BufferAllocation::Slice inv_stddev_;
  BufferAllocation::Slice grad_output_;
  float epsilon_;
  int64 feature_index_;
  // Outputs (gradients w.r.t. data, scale, and offset).
  BufferAllocation::Slice output_grad_data_;
  BufferAllocation::Slice output_grad_scale_;
  BufferAllocation::Slice output_grad_offset_;
  BufferAllocation::Slice output_tuple_;
};
145 
146 }  // namespace gpu
147 }  // namespace xla
148 
149 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_THUNK_H_
150