/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_THUNK_H_

#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"

namespace xla {
namespace gpu {

// This file contains thunks which call into cudnn to run the various flavors
// of batch normalization: BatchNormInference, BatchNormTraining, and
// BatchNormGrad, known to cudnn as BatchNormForwardInference,
// BatchNormForwardTraining, and BatchNormBackward.
//
// As an alternative to using these thunks, XLA can decompose batchnorm HLOs
// into smaller components using the BatchNormRewriter pass.  This can result
// in faster code because those individual components can fuse into their
// inputs/outputs, but it may also be slower if cudnn's batchnorm
// implementation outperforms the code XLA generates for these components.
//
// Currently these thunks require that their inputs are F32s.
//
// Note that these thunks do not take full advantage of the cudnn batchnorm
// functions.  For example, cudnn lets you bias and/or scale the input/output,
// but these thunks don't currently support that.
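//
// A minimal illustrative sketch (kept as a comment; it is not part of this
// header's API) of how a caller might construct and run the inference thunk.
// The *_slice variables, `bn_hlo`, `buffer_allocations`, `stream`, and
// `profiler` are hypothetical placeholders; in real XLA code they come from
// buffer assignment and the GPU executable, respectively.
//
//   CudnnBatchNormForwardInferenceThunk thunk(
//       /*operand=*/operand_slice, /*scale=*/scale_slice,
//       /*offset=*/offset_slice, /*mean=*/mean_slice,
//       /*variance=*/variance_slice,
//       /*epsilon=*/1e-3f, /*feature_index=*/3,  // e.g. the C dim of NHWC.
//       /*output=*/output_slice, /*hlo=*/bn_hlo);
//   TF_RETURN_IF_ERROR(
//       thunk.ExecuteOnStream(buffer_allocations, stream, profiler));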

class CudnnBatchNormForwardInferenceThunk : public Thunk {
 public:
  CudnnBatchNormForwardInferenceThunk(const BufferAllocation::Slice& operand,
                                      const BufferAllocation::Slice& scale,
                                      const BufferAllocation::Slice& offset,
                                      const BufferAllocation::Slice& mean,
                                      const BufferAllocation::Slice& variance,
                                      float epsilon, int64 feature_index,
                                      const BufferAllocation::Slice& output,
                                      const HloInstruction* hlo);

  CudnnBatchNormForwardInferenceThunk(
      const CudnnBatchNormForwardInferenceThunk&) = delete;
  CudnnBatchNormForwardInferenceThunk& operator=(
      const CudnnBatchNormForwardInferenceThunk&) = delete;

  Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
                         se::Stream* stream,
                         HloExecutionProfiler* profiler) override;

 private:
  BufferAllocation::Slice operand_;
  BufferAllocation::Slice scale_;
  BufferAllocation::Slice offset_;
  BufferAllocation::Slice mean_;
  BufferAllocation::Slice variance_;
  float epsilon_;
  int64 feature_index_;
  BufferAllocation::Slice output_;
};

class CudnnBatchNormForwardTrainingThunk : public Thunk {
 public:
  CudnnBatchNormForwardTrainingThunk(
      const BufferAllocation::Slice& operand,
      const BufferAllocation::Slice& scale,
      const BufferAllocation::Slice& offset, float epsilon,
      int64 feature_index, const BufferAllocation::Slice& output_data,
      const BufferAllocation::Slice& output_mean,
      const BufferAllocation::Slice& output_inv_stddev,
      const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo);

  CudnnBatchNormForwardTrainingThunk(
      const CudnnBatchNormForwardTrainingThunk&) = delete;
  CudnnBatchNormForwardTrainingThunk& operator=(
      const CudnnBatchNormForwardTrainingThunk&) = delete;

  Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
                         se::Stream* stream,
                         HloExecutionProfiler* profiler) override;

 private:
  BufferAllocation::Slice operand_;
  BufferAllocation::Slice scale_;
  BufferAllocation::Slice offset_;
  float epsilon_;
  int64 feature_index_;
  BufferAllocation::Slice output_data_;
  BufferAllocation::Slice output_mean_;
  BufferAllocation::Slice output_inv_stddev_;
  BufferAllocation::Slice output_tuple_;
};

class CudnnBatchNormBackwardThunk : public Thunk {
 public:
  CudnnBatchNormBackwardThunk(const BufferAllocation::Slice& operand,
                              const BufferAllocation::Slice& scale,
                              const BufferAllocation::Slice& mean,
                              const BufferAllocation::Slice& inv_stddev,
                              const BufferAllocation::Slice& grad_output,
                              float epsilon, int64 feature_index,
                              const BufferAllocation::Slice& output_grad_data,
                              const BufferAllocation::Slice& output_grad_scale,
                              const BufferAllocation::Slice& output_grad_offset,
                              const BufferAllocation::Slice& output_tuple,
                              const HloInstruction* hlo);

  CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete;
  CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) =
      delete;

  Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
                         se::Stream* stream,
                         HloExecutionProfiler* profiler) override;

 private:
  BufferAllocation::Slice operand_;
  BufferAllocation::Slice scale_;
  BufferAllocation::Slice mean_;
  BufferAllocation::Slice inv_stddev_;
  BufferAllocation::Slice grad_output_;
  float epsilon_;
  int64 feature_index_;
  BufferAllocation::Slice output_grad_data_;
  BufferAllocation::Slice output_grad_scale_;
  BufferAllocation::Slice output_grad_offset_;
  BufferAllocation::Slice output_tuple_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_THUNK_H_