/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/lib/strings/str_util.h"
#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"

#include "tensorflow/core/util/cuda_kernel_helper.h"

#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"

namespace tensorflow {

namespace {

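// strict_cast<U, T> is an exact cast: it is the identity when U and T are the
// same type, and is otherwise only defined for the conversions explicitly
// specialized below (Eigen::half <-> float). Any other combination of types
// fails to compile, which keeps the data type and the accumulator type from
// being mixed up silently.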
template <typename U, typename T>
__device__ __host__ EIGEN_STRONG_INLINE
    typename std::enable_if<!std::is_same<T, U>::value, U>::type
    strict_cast(T t);

template <typename U, typename T>
__device__ __host__ EIGEN_STRONG_INLINE
    typename std::enable_if<std::is_same<T, U>::value, U>::type
    strict_cast(T t) {
  return t;
}

template <>
__device__ __host__ EIGEN_STRONG_INLINE float strict_cast<float, Eigen::half>(
    Eigen::half t) {
  return functor::HalfToFloat()(t);
}

template <>
__device__ __host__ EIGEN_STRONG_INLINE Eigen::half
strict_cast<Eigen::half, float>(float t) {
  return functor::FloatToHalf()(t);
}

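// softmax_traits maps an element type to the type used for the row sums.
// Eigen::half accumulates in float: summing many half-precision terms loses
// precision quickly (half carries only a 10-bit mantissa), and the running
// sum can exceed half's limited range.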
template <typename T>
struct softmax_traits {
  using accumulator_type = T;
};

template <>
struct softmax_traits<Eigen::half> {
  using accumulator_type = float;
};

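// One thread per output element. For element (row, col) this computes
//   softmax:     exp(x - row_max) / row_sum
//   log softmax: (x - row_max) - log(row_sum)
// where row_max is the row's maximum logit and row_sum is
// sum(exp(x - row_max)) over the row, both produced by earlier passes.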
template <typename T, typename U>
__global__ void GenerateNormalizedProb(const T* logits, const U* sum_probs,
                                       const T* max_logits, T* output,
                                       const int num_rows, const int num_cols,
                                       const bool in_log_space) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;

  const int row = tid / num_cols;
  const int col = tid % num_cols;

  // Guard before loading: threads past the last element must not read or
  // write out of bounds.
  if (row < num_rows && col < num_cols) {
    // TODO(jamesqin): change to half2 load when inputs are Eigen::half.
    U input = strict_cast<U>(logits[tid]);
    U max_val = strict_cast<U>(ldg(max_logits + row));
    U result;
    if (in_log_space) {
      result = input - max_val - log(ldg(sum_probs + row));
    } else {
      result = exp(input - max_val) / ldg(sum_probs + row);
    }
    output[tid] = strict_cast<T>(result);
  }
}

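// Computes exp(logits[gid] - row_max) on the fly. Wrapped in a
// cub::TransformInputIterator below, this lets the sum reduction consume the
// shifted exponentials without ever materializing them in memory.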
template <typename T, typename U>
struct SubtractAndExpFunctor {
  __host__ __device__ SubtractAndExpFunctor(const T* logits,
                                            const T* max_logits,
                                            const int num_cols)
      : logits_(logits), max_logits_(max_logits), num_cols_(num_cols) {}

  __host__ __device__ U operator()(const int gid) const {
    // TODO(jamesqin): change to half2 load when inputs are Eigen::half.
    const U diff =
        strict_cast<U>(logits_[gid] - ldg(max_logits_ + gid / num_cols_));
    return exp(diff);
  }

  const T* logits_;
  const T* max_logits_;
  const int num_cols_;
};

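// Reduces each row of a [rows, cols] matrix with the given cub reduction op,
// writing one result per row to `output`. InputIter may be a raw pointer or
// a transform iterator, so the reduced values can be computed lazily.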
template <typename T, typename Op, typename InputIter>
void DoRowReduction(OpKernelContext* context, T* output, InputIter input,
                    int rows, int cols) {
  typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
  Constants<GPUDevice> constants;

  Op op;

  functor::ReduceImpl<T, Op, T*, InputIter, ReductionAxes>(
      context, output, input, 2, rows, cols, 1, 1, constants.kOne, op);
}
}  // namespace

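// GPU softmax/log-softmax over the innermost dimension, in three passes:
//   1. a cub::Max row reduction to find each row's maximum logit,
//   2. a cub::Sum row reduction of exp(logit - row_max), computed lazily
//      through SubtractAndExpFunctor,
//   3. the GenerateNormalizedProb kernel to normalize (or take the log of)
//      each element.
// Subtracting the row maximum first is the standard trick that keeps exp()
// from overflowing.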
template <typename T>
class SoftmaxOpGPU : public OpKernel {
 public:
  explicit SoftmaxOpGPU(OpKernelConstruction* context) : OpKernel(context) {
    log_ = str_util::StartsWith(type_string(), "Log");
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& logits_in_ = context->input(0);
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(logits_in_.shape()),
                errors::InvalidArgument("logits must have >= 1 dimension, got ",
                                        logits_in_.shape().DebugString()));
    auto logits_in = logits_in_.flat_inner_dims<T>();
    const int rows = logits_in.dimension(0);
    const int cols = logits_in.dimension(1);
    Tensor* softmax_out = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, logits_in_.shape(), &softmax_out));

    const cudaStream_t& cu_stream = GetCudaStream(context);
    if (logits_in_.NumElements() > 0) {
      Tensor max_logits;
      Tensor sum_probs;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<T>::value,
                                            softmax_out->shape(), &max_logits));

      typedef typename softmax_traits<T>::accumulator_type acc_type;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<acc_type>::value,
                                            softmax_out->shape(), &sum_probs));

      // Pass 1: per-row maximum of the logits.
      DoRowReduction<T, cub::Max, const T*>(
          context, const_cast<T*>(max_logits.flat<T>().data()),
          reinterpret_cast<const T*>(logits_in_.flat<T>().data()), rows, cols);

      const int numThreads = 128;
      const int numBlocks = Eigen::divup(rows * cols, numThreads);

      cub::CountingInputIterator<int> counting_iterator(0);
      typedef cub::TransformInputIterator<acc_type,
                                          SubtractAndExpFunctor<T, acc_type>,
                                          cub::CountingInputIterator<int>>
          InputIterType;

      InputIterType input_itr(
          counting_iterator,
          SubtractAndExpFunctor<T, acc_type>(
              reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
              reinterpret_cast<const T*>(max_logits.flat<T>().data()), cols));

      // Pass 2: per-row sum of exp(logit - row_max), computed on the fly.
      DoRowReduction<acc_type, cub::Sum, InputIterType>(
          context, const_cast<acc_type*>(sum_probs.flat<acc_type>().data()),
          input_itr, rows, cols);

      // Pass 3: normalize, one thread per element.
      TF_CHECK_OK(CudaLaunchKernel(
          GenerateNormalizedProb<T, acc_type>, numBlocks, numThreads, 0,
          cu_stream, reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
          reinterpret_cast<const acc_type*>(sum_probs.flat<acc_type>().data()),
          reinterpret_cast<const T*>(max_logits.flat<T>().data()),
          const_cast<T*>(softmax_out->flat<T>().data()), rows, cols, log_));
    }
  }

 private:
  bool log_;
};

REGISTER_KERNEL_BUILDER(
    Name("Softmax").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    SoftmaxOpGPU<Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("Softmax").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    SoftmaxOpGPU<float>);
REGISTER_KERNEL_BUILDER(
    Name("Softmax").Device(DEVICE_GPU).TypeConstraint<double>("T"),
    SoftmaxOpGPU<double>);
REGISTER_KERNEL_BUILDER(
    Name("LogSoftmax").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    SoftmaxOpGPU<Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("LogSoftmax").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    SoftmaxOpGPU<float>);

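// Note that Softmax is registered for double on GPU while LogSoftmax is not;
// the same SoftmaxOpGPU class serves both ops, selecting log space from the
// op name in its constructor.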
}  // end namespace tensorflow

#endif  // GOOGLE_CUDA