/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_
#define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_

// Functor definitions for Reduction ops, must be compilable by nvcc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {
namespace functor {

template <typename Reducer>
struct ReducerTraits {
  enum { IsScalarIdentity = true };
};

// Dummy class used for template specialization for mean reduction, which is
// accomplished by SumReducer and on-the-fly division by the reduction factor.
template <typename Scalar>
struct MeanReducer {
  Scalar initialize() const { return Scalar(0); }
};

// Dummy class used for template specialization for l2-norm reduction.
template <typename Scalar>
struct EuclideanNormReducer {
  Scalar initialize() const { return Scalar(0); }
};

template <typename Scalar>
struct ReducerTraits<EuclideanNormReducer<Scalar>> {
  enum { IsScalarIdentity = false };
};

template <typename Device, typename OUT_T, typename IN_T,
          typename ReductionAxes, typename Reducer>
struct ReduceEigenImpl {
  void operator()(const Device& d, OUT_T out, IN_T in,
                  const ReductionAxes& reduction_axes, const Reducer& reducer) {
    out.device(d) = in.reduce(reduction_axes, reducer);
  }
};

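// Illustrative sketch (not part of the original header): the generic
// ReduceEigenImpl above simply forwards to Eigen's reduce(). Assuming plain
// Eigen TensorMap views and Eigen::DefaultDevice, summing a 2x4 float tensor
// over its second dimension could be driven as
//
//   float in_data[8] = {1, 2, 3, 4, 5, 6, 7, 8};
//   float out_data[2];
//   Eigen::TensorMap<Eigen::Tensor<float, 2>> in(in_data, 2, 4);
//   Eigen::TensorMap<Eigen::Tensor<float, 1>> out(out_data, 2);
//   Eigen::array<Eigen::Index, 1> axes = {1};
//   Eigen::internal::SumReducer<float> sum_reducer;
//   ReduceEigenImpl<Eigen::DefaultDevice, decltype(out), decltype(in),
//                   decltype(axes), Eigen::internal::SumReducer<float>>()(
//       Eigen::DefaultDevice(), out, in, axes, sum_reducer);
//
// In the TensorFlow kernels the OUT_T/IN_T arguments are typically
// TTypes<T>::Tensor-style views over tensorflow::Tensor buffers rather than
// raw Eigen tensors.
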
// Specialization for BF16 Reducer to fix accuracy.
// TODO: All BF16 reducers should have specializations to fix accuracy.
#define CASTING_SPECIALIZATION(Reducer, ScalarType, IntermediateType) \
  template <typename Device, typename OUT_T, typename IN_T, \
            typename ReductionAxes> \
  struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes, \
                         Reducer<ScalarType>> { \
    void operator()(const Device& d, OUT_T out, IN_T in, \
                    const ReductionAxes& reduction_axes, \
                    const Reducer<ScalarType>& reducer) { \
      static_assert(std::is_same<ScalarType, typename OUT_T::Scalar>::value, \
                    ""); \
      Reducer<IntermediateType> intermediate_reducer; \
      auto in_as_intermediate = in.template cast<IntermediateType>(); \
      out.device(d) = \
          in_as_intermediate.reduce(reduction_axes, intermediate_reducer) \
              .template cast<ScalarType>(); \
    } \
  };

CASTING_SPECIALIZATION(Eigen::internal::SumReducer, bfloat16, float);
#undef CASTING_SPECIALIZATION

template <typename Device, typename OUT_T, typename IN_T,
          typename ReductionAxes, typename Scalar>
struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
                       functor::MeanReducer<Scalar>> {
  void operator()(const Device& d, OUT_T out, IN_T in,
                  const ReductionAxes& reduction_axes,
                  const functor::MeanReducer<Scalar>& reducer) {
    static_assert(std::is_same<Scalar, typename OUT_T::Scalar>::value, "");
    Eigen::internal::SumReducer<Scalar> sum_reducer;
    out.device(d) = in.reduce(reduction_axes, sum_reducer) /
                    static_cast<Scalar>(in.size() / out.size());
  }
};

// Specialization for which we do the reduction in IntermediateType to
// avoid integer overflow.
#define CASTING_SPECIALIZATION(ScalarType, IntermediateType) \
  template <typename Device, typename OUT_T, typename IN_T, \
            typename ReductionAxes> \
  struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes, \
                         functor::MeanReducer<ScalarType>> { \
    void operator()(const Device& d, OUT_T out, IN_T in, \
                    const ReductionAxes& reduction_axes, \
                    const functor::MeanReducer<ScalarType>& reducer) { \
      static_assert(std::is_same<ScalarType, typename OUT_T::Scalar>::value, \
                    ""); \
      Eigen::internal::SumReducer<IntermediateType> sum_reducer; \
      out.device(d) = (in.template cast<IntermediateType>().reduce( \
                           reduction_axes, sum_reducer) / \
                       static_cast<IntermediateType>(in.size() / out.size())) \
                          .template cast<ScalarType>(); \
    } \
  }

CASTING_SPECIALIZATION(uint8, uint64);
CASTING_SPECIALIZATION(uint16, uint64);
CASTING_SPECIALIZATION(uint32, uint64);
CASTING_SPECIALIZATION(int8, int64);
CASTING_SPECIALIZATION(int16, int64);
CASTING_SPECIALIZATION(int32, int64);
#undef CASTING_SPECIALIZATION

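// Illustrative note (not part of the original header): the widening casts
// above matter because the mean is computed as a full sum followed by a
// single division. For example, the mean of 1000 uint8 values all equal to
// 200 needs a running sum of 200000, which fits in neither uint8 nor uint16;
// accumulating in uint64 and casting only the final quotient
//
//   200000 / 1000 = 200
//
// back to uint8 keeps the result exact.
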
// TODO(rmlarsen): Refactor this such that taking the sqrt can be optionally
// controlled by an attribute.
template <typename Device, typename OUT_T, typename IN_T,
          typename ReductionAxes, typename Scalar>
struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
                       functor::EuclideanNormReducer<Scalar>> {
  void operator()(const Device& d, OUT_T out, IN_T in,
                  const ReductionAxes& reduction_axes,
                  const functor::EuclideanNormReducer<Scalar>& reducer) {
    static_assert(std::is_same<Scalar, typename OUT_T::Scalar>::value, "");
    Eigen::internal::SumReducer<Scalar> sum_reducer;
    out.device(d) =
        (in * in.conjugate()).reduce(reduction_axes, sum_reducer).sqrt();
  }
};

template <typename Device, typename OUT_T, typename IN_T,
          typename ReductionAxes>
struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
                       functor::EuclideanNormReducer<bfloat16>> {
  void operator()(const Device& d, OUT_T out, IN_T in,
                  const ReductionAxes& reduction_axes,
                  const functor::EuclideanNormReducer<bfloat16>& reducer) {
    static_assert(std::is_same<bfloat16, typename OUT_T::Scalar>::value, "");
    Eigen::internal::SumReducer<float> sum_reducer;
    auto in_as_float = in.template cast<float>();
    out.device(d) = (in_as_float * in_as_float.conjugate())
                        .reduce(reduction_axes, sum_reducer)
                        .sqrt()
                        .template cast<bfloat16>();
  }
};

// For most reducers, the identity is Reducer::initialize()
template <typename Reducer>
struct Identity {
  static auto identity(const Reducer& reducer)
      -> decltype(reducer.initialize()) {
    return reducer.initialize();
  }
};

// MeanReducer is a special case, since it doesn't technically have an
// identity. Thus, ideally we'd return nan. However, mean is instantiated for
// integer types as well, so we do the nan override only for floating point
// types.
#define FIX_MEAN_IDENTITY(T) \
  template <> \
  struct Identity<functor::MeanReducer<T>> { \
    static T identity(const functor::MeanReducer<T>&) { \
      return Eigen::NumTraits<T>::quiet_NaN(); \
    } \
  };
FIX_MEAN_IDENTITY(Eigen::half)
FIX_MEAN_IDENTITY(float)
FIX_MEAN_IDENTITY(double)
#undef FIX_MEAN_IDENTITY

template <typename Device, typename OUT_T, typename Reducer>
void FillIdentityEigenImpl(const Device& d, OUT_T out, const Reducer& reducer) {
  out.device(d) = out.constant(Identity<Reducer>::identity(reducer));
}

template <typename Device, typename Reducer>
struct ReduceFunctor {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Reducer& reducer);

  template <typename OUT_T>
  static void FillIdentity(const Device& d, OUT_T out, const Reducer& reducer);
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_
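
// Illustrative sketch (not part of the original header): a device-specific
// ReduceFunctor specialization in the kernel code is expected to forward to
// the helpers declared above, roughly along the lines of
//
//   template <typename Reducer>
//   struct ReduceFunctor<Eigen::ThreadPoolDevice, Reducer> {
//     template <typename OUT_T, typename IN_T, typename ReductionAxes>
//     static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
//                        const ReductionAxes& reduction_axes,
//                        const Reducer& reducer) {
//       ReduceEigenImpl<Eigen::ThreadPoolDevice, OUT_T, IN_T, ReductionAxes,
//                       Reducer>()(ctx->eigen_cpu_device(), out, in,
//                                  reduction_axes, reducer);
//     }
//     template <typename OUT_T>
//     static void FillIdentity(const Eigen::ThreadPoolDevice& d, OUT_T out,
//                              const Reducer& reducer) {
//       FillIdentityEigenImpl(d, out, reducer);
//     }
//   };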