/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_
#define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_

// Functor definitions for reduction ops; must be compilable by nvcc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {
namespace functor {

template <typename Reducer>
struct ReducerTraits {
  enum { IsScalarIdentity = true };
};

// Dummy class used for template specialization for mean reduction, which is
// accomplished by SumReducer and on-the-fly division by the reduction factor.
template <typename Scalar>
struct MeanReducer {
  Scalar initialize() const { return Scalar(0); }
};

// Dummy class used for template specialization for l2-norm reduction.
template <typename Scalar>
struct EuclideanNormReducer {
  Scalar initialize() const { return Scalar(0); }
};

template <typename Scalar>
struct ReducerTraits<EuclideanNormReducer<Scalar>> {
  enum { IsScalarIdentity = false };
};

template <typename Device, typename OUT_T, typename IN_T,
          typename ReductionAxes, typename Reducer>
struct ReduceEigenImpl {
  void operator()(const Device& d, OUT_T out, IN_T in,
                  const ReductionAxes& reduction_axes, const Reducer& reducer) {
    out.device(d) = in.reduce(reduction_axes, reducer);
  }
};
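
// Illustrative usage sketch (not part of the real kernels; CPUDevice,
// OutTensorMap, and InTensorMap are placeholders for whatever the calling
// kernel actually supplies):
//
//   Eigen::array<Eigen::Index, 1> axes = {0};
//   Eigen::internal::SumReducer<float> sum_reducer;
//   ReduceEigenImpl<CPUDevice, OutTensorMap, InTensorMap, decltype(axes),
//                   Eigen::internal::SumReducer<float>>()(
//       device, out, in, axes, sum_reducer);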

// Specialization for BF16 Reducer to fix accuracy.
// TODO: All BF16 reducers should have specializations to fix accuracy.
#define CASTING_SPECIALIZATION(Reducer, ScalarType, IntermediateType)        \
  template <typename Device, typename OUT_T, typename IN_T,                  \
            typename ReductionAxes>                                          \
  struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,                 \
                         Reducer<ScalarType>> {                              \
    void operator()(const Device& d, OUT_T out, IN_T in,                     \
                    const ReductionAxes& reduction_axes,                     \
                    const Reducer<ScalarType>& reducer) {                    \
      static_assert(std::is_same<ScalarType, typename OUT_T::Scalar>::value, \
                    "");                                                     \
      Reducer<IntermediateType> intermediate_reducer;                        \
      auto in_as_intermediate = in.template cast<IntermediateType>();        \
      out.device(d) =                                                        \
          in_as_intermediate.reduce(reduction_axes, intermediate_reducer)    \
              .template cast<ScalarType>();                                  \
    }                                                                        \
  };

CASTING_SPECIALIZATION(Eigen::internal::SumReducer, bfloat16, float);
#undef CASTING_SPECIALIZATION

template <typename Device, typename OUT_T, typename IN_T,
          typename ReductionAxes, typename Scalar>
struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
                       functor::MeanReducer<Scalar>> {
  void operator()(const Device& d, OUT_T out, IN_T in,
                  const ReductionAxes& reduction_axes,
                  const functor::MeanReducer<Scalar>& reducer) {
    static_assert(std::is_same<Scalar, typename OUT_T::Scalar>::value, "");
    Eigen::internal::SumReducer<Scalar> sum_reducer;
    out.device(d) = in.reduce(reduction_axes, sum_reducer) /
                    static_cast<Scalar>(in.size() / out.size());
  }
};
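
// Note on the divisor above: in.size() / out.size() is the number of input
// elements folded into each output element. For example (illustrative only),
// reducing a 2x3 input over its second dimension leaves out.size() == 2, so
// each output is the corresponding row sum divided by 3.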

// Specializations for which we do the reduction in IntermediateType to
// avoid integer overflow.
#define CASTING_SPECIALIZATION(ScalarType, IntermediateType)                  \
  template <typename Device, typename OUT_T, typename IN_T,                   \
            typename ReductionAxes>                                           \
  struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,                  \
                         functor::MeanReducer<ScalarType>> {                  \
    void operator()(const Device& d, OUT_T out, IN_T in,                      \
                    const ReductionAxes& reduction_axes,                      \
                    const functor::MeanReducer<ScalarType>& reducer) {        \
      static_assert(std::is_same<ScalarType, typename OUT_T::Scalar>::value,  \
                    "");                                                      \
      Eigen::internal::SumReducer<IntermediateType> sum_reducer;              \
      out.device(d) = (in.template cast<IntermediateType>().reduce(           \
                           reduction_axes, sum_reducer) /                     \
                       static_cast<IntermediateType>(in.size() / out.size())) \
                          .template cast<ScalarType>();                       \
    }                                                                         \
  }

CASTING_SPECIALIZATION(uint8, uint64);
CASTING_SPECIALIZATION(uint16, uint64);
CASTING_SPECIALIZATION(uint32, uint64);
CASTING_SPECIALIZATION(int8, int64);
CASTING_SPECIALIZATION(int16, int64);
CASTING_SPECIALIZATION(int32, int64);
#undef CASTING_SPECIALIZATION
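
// Why the wider accumulator matters (illustrative only): averaging, say, 1000
// int8 values that are each 100 would wrap around if the running sum were kept
// in int8, while accumulating in int64 preserves the exact sum (100000) and
// only the final quotient (100) is narrowed back to int8.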

// TODO(rmlarsen): Refactor this such that taking the sqrt can be made
// optional, controlled by an attribute.
template <typename Device, typename OUT_T, typename IN_T,
          typename ReductionAxes, typename Scalar>
struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
                       functor::EuclideanNormReducer<Scalar>> {
  void operator()(const Device& d, OUT_T out, IN_T in,
                  const ReductionAxes& reduction_axes,
                  const functor::EuclideanNormReducer<Scalar>& reducer) {
    static_assert(std::is_same<Scalar, typename OUT_T::Scalar>::value, "");
    Eigen::internal::SumReducer<Scalar> sum_reducer;
    out.device(d) =
        (in * in.conjugate()).reduce(reduction_axes, sum_reducer).sqrt();
  }
};
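
// In other words, the specialization above computes sqrt(sum_i |x_i|^2) over
// the reduced axes; multiplying each element by its conjugate keeps the
// expression correct for complex inputs, where x * conj(x) is the squared
// magnitude.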

template <typename Device, typename OUT_T, typename IN_T,
          typename ReductionAxes>
struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
                       functor::EuclideanNormReducer<bfloat16>> {
  void operator()(const Device& d, OUT_T out, IN_T in,
                  const ReductionAxes& reduction_axes,
                  const functor::EuclideanNormReducer<bfloat16>& reducer) {
    static_assert(std::is_same<bfloat16, typename OUT_T::Scalar>::value, "");
    Eigen::internal::SumReducer<float> sum_reducer;
    auto in_as_float = in.template cast<float>();
    out.device(d) = (in_as_float * in_as_float.conjugate())
                        .reduce(reduction_axes, sum_reducer)
                        .sqrt()
                        .template cast<bfloat16>();
  }
};

// For most reducers, the identity is Reducer::initialize().
template <typename Reducer>
struct Identity {
  static auto identity(const Reducer& reducer)
      -> decltype(reducer.initialize()) {
    return reducer.initialize();
  }
};

// MeanReducer is a special case, since it doesn't technically have an
// identity. Thus, ideally we'd return NaN. However, mean is instantiated for
// integer types as well, so we do the NaN override only for floating point
// types.
#define FIX_MEAN_IDENTITY(T)                            \
  template <>                                           \
  struct Identity<functor::MeanReducer<T>> {            \
    static T identity(const functor::MeanReducer<T>&) { \
      return Eigen::NumTraits<T>::quiet_NaN();          \
    }                                                   \
  };
FIX_MEAN_IDENTITY(Eigen::half)
FIX_MEAN_IDENTITY(float)
FIX_MEAN_IDENTITY(double)
#undef FIX_MEAN_IDENTITY

template <typename Device, typename OUT_T, typename Reducer>
void FillIdentityEigenImpl(const Device& d, OUT_T out, const Reducer& reducer) {
  out.device(d) = out.constant(Identity<Reducer>::identity(reducer));
}
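
// A kernel can fall back to FillIdentityEigenImpl when there is nothing to
// reduce (e.g. the reduced dimensions are empty), filling the output with the
// reducer's identity element, or with NaN for floating-point mean per the
// specializations above.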

template <typename Device, typename Reducer>
struct ReduceFunctor {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Reducer& reducer);

  template <typename OUT_T>
  static void FillIdentity(const Device& d, OUT_T out, const Reducer& reducer);
};
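
// Device-specific definitions of ReduceFunctor::Reduce and
// ReduceFunctor::FillIdentity are expected to live in the corresponding
// implementation files (e.g. the CPU and GPU kernel sources); this header only
// fixes the interface they share.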

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_