1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_ 17 #define TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_ 18 19 #define EIGEN_USE_THREADS 20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 21 #include "tensorflow/core/kernels/cwise_ops.h" 22 23 namespace Eigen { 24 namespace internal { 25 26 // Gradient for the tanh function 27 template <typename T> 28 struct scalar_tanh_gradient_op { EIGEN_EMPTY_STRUCT_CTORscalar_tanh_gradient_op29 EIGEN_EMPTY_STRUCT_CTOR(scalar_tanh_gradient_op) 30 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T 31 operator()(const T& output, const T& output_gradient) const { 32 return output_gradient * (T(1) - output * output); 33 } 34 template <typename Packet> 35 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOpscalar_tanh_gradient_op36 packetOp(const Packet& output, const Packet& output_gradient) const { 37 return pmul(output_gradient, 38 psub(pset1<Packet>(T(1)), pmul(output, output))); 39 } 40 }; 41 template <typename T> 42 struct functor_traits<scalar_tanh_gradient_op<T>> { 43 enum { 44 Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost, 45 PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul, 46 }; 47 }; 48 49 // Gradient for the sigmoid function 50 template <typename T> 51 struct scalar_sigmoid_gradient_op { 52 EIGEN_EMPTY_STRUCT_CTOR(scalar_sigmoid_gradient_op) 53 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T 54 operator()(const T& output, const T& output_gradient) const { 55 return output_gradient * output * (T(1) - output); 56 } 57 template <typename Packet> 58 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 59 packetOp(const Packet& output, const Packet& output_gradient) const { 60 return pmul(output_gradient, 61 pmul(output, psub(pset1<Packet>(T(1)), output))); 62 } 63 }; 64 template <typename T> 65 struct functor_traits<scalar_sigmoid_gradient_op<T>> { 66 enum { 67 Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost, 68 PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul, 69 }; 70 }; 71 72 // Gradient for the inverse function 73 template <typename T> 74 struct scalar_inverse_gradient_op { 75 EIGEN_EMPTY_STRUCT_CTOR(scalar_inverse_gradient_op) 76 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T 77 operator()(const T& output, const T& output_gradient) const { 78 if (output_gradient == T(0)) { 79 return T(0); 80 } else { 81 const T out_conj = numext::conj(output); 82 return -out_conj * out_conj * output_gradient; 83 } 84 } 85 template <typename Packet> 86 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 87 packetOp(const Packet& output, const Packet& output_gradient) const { 88 const Packet out_conj = pconj(output); 89 return mul_no_nan_op<T>().packetOp(pnegate(pmul(out_conj, out_conj)), 90 output_gradient); 91 } 92 }; 93 template <typename T> 94 struct functor_traits<scalar_inverse_gradient_op<T>> { 95 enum { 96 Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost, 97 PacketAccess = packet_traits<T>::HasMul, 98 }; 99 }; 100 101 // Gradient for the sqrt function 102 template <typename T> 103 struct scalar_sqrt_gradient_op { 104 EIGEN_EMPTY_STRUCT_CTOR(scalar_sqrt_gradient_op) 105 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T 106 operator()(const T& output, const T& output_gradient) const { 107 if (output_gradient == T(0)) { 108 return T(0); 109 } else { 110 const T out_conj = numext::conj(output); 111 return (static_cast<T>(0.5) * output_gradient) / out_conj; 112 } 113 } 114 template <typename Packet> 115 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 116 packetOp(const Packet& output, const Packet& output_gradient) const { 117 const Packet const_half = pset1<Packet>(static_cast<T>(0.5)); 118 const Packet out_conj = pconj(output); 119 return mul_no_nan_op<T>().packetOp(pdiv(const_half, out_conj), 120 output_gradient); 121 } 122 }; 123 template <typename T> 124 struct functor_traits<scalar_sqrt_gradient_op<T>> { 125 enum { 126 PacketAccess = packet_traits<T>::HasMul & packet_traits<T>::HasDiv, 127 Cost = NumTraits<T>::MulCost + scalar_div_cost<T, PacketAccess>::value, 128 }; 129 }; 130 131 // Gradient for the rsqrt function 132 template <typename T> 133 struct scalar_rsqrt_gradient_op { 134 EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_gradient_op) 135 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T 136 operator()(const T& output, const T& output_gradient) const { 137 if (output_gradient == T(0)) { 138 return T(0); 139 } else { 140 const T out_conj = numext::conj(output); 141 return static_cast<T>(-0.5) * (output_gradient * out_conj) * 142 (out_conj * out_conj); 143 } 144 } 145 template <typename Packet> 146 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 147 packetOp(const Packet& output, const Packet& output_gradient) const { 148 const Packet const_half = pset1<Packet>(static_cast<T>(-0.5)); 149 const Packet out_conj = pconj(output); 150 auto safe_pmul = [](const Packet& a, const Packet& b) { 151 return mul_no_nan_op<T>().packetOp(a, b); 152 }; 153 return safe_pmul(pmul(const_half, pmul(out_conj, out_conj)), 154 safe_pmul(out_conj, output_gradient)); 155 } 156 }; 157 template <typename T> 158 struct functor_traits<scalar_rsqrt_gradient_op<T>> { 159 enum { 160 Cost = 4 * NumTraits<T>::MulCost, 161 PacketAccess = packet_traits<T>::HasMul, 162 }; 163 }; 164 165 } // end namespace internal 166 } // end namespace Eigen 167 168 namespace tensorflow { 169 170 namespace functor { 171 172 template <typename Device, typename Functor> 173 struct SimpleBinaryFunctor { 174 void operator()(const Device& d, typename Functor::tout_type out, 175 typename Functor::tin_type in0, 176 typename Functor::tin_type in1); 177 }; 178 179 // Partial specialization of BinaryFunctor for CPU devices 180 typedef Eigen::ThreadPoolDevice CPUDevice; 181 182 template <typename Functor> 183 struct SimpleBinaryFunctor<CPUDevice, Functor> { 184 void operator()(const CPUDevice& d, typename Functor::tout_type out, 185 typename Functor::tin_type in0, 186 typename Functor::tin_type in1) { 187 out.device(d) = in0.binaryExpr(in1, typename Functor::func()); 188 } 189 }; 190 191 192 template <typename T> 193 struct tanh_grad : base<T, Eigen::internal::scalar_tanh_gradient_op<T>> {}; 194 195 template <typename T> 196 struct sigmoid_grad : base<T, Eigen::internal::scalar_sigmoid_gradient_op<T>> { 197 }; 198 199 template <typename T> 200 struct inverse_grad : base<T, Eigen::internal::scalar_inverse_gradient_op<T>> { 201 }; 202 203 template <typename T> 204 struct sqrt_grad : base<T, Eigen::internal::scalar_sqrt_gradient_op<T>> {}; 205 206 template <typename T> 207 struct rsqrt_grad : base<T, Eigen::internal::scalar_rsqrt_gradient_op<T>> {}; 208 209 template <typename T> 210 struct igamma_grad_a : base<T, Eigen::internal::scalar_igamma_der_a_op<T>> {}; 211 212 } // end namespace functor 213 214 } // end namespace tensorflow 215 #endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_ 216