1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_ 17 #define TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_ 18 19 // Functor definitions for Aggregate ops, must be compilable by nvcc. 20 21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 22 #include "tensorflow/core/framework/tensor_types.h" 23 24 namespace tensorflow { 25 namespace functor { 26 27 template <typename Device, typename T> 28 struct Add2Functor { 29 void operator()(const Device& d, typename TTypes<T>::Flat out, 30 typename TTypes<T>::ConstFlat in1, 31 typename TTypes<T>::ConstFlat in2); 32 }; 33 34 template <typename Device, typename T> 35 struct Add2EigenImpl { ComputeAdd2EigenImpl36 static void Compute(const Device& d, typename TTypes<T>::Flat out, 37 typename TTypes<T>::ConstFlat in1, 38 typename TTypes<T>::ConstFlat in2) { 39 out.device(d) = in1 + in2; 40 } 41 }; 42 43 template <typename Device, typename T> 44 struct Add3Functor { 45 void operator()(const Device& d, typename TTypes<T>::Flat out, 46 typename TTypes<T>::ConstFlat in1, 47 typename TTypes<T>::ConstFlat in2, 48 typename TTypes<T>::ConstFlat in3); 49 }; 50 51 template <typename Device, typename T> 52 struct Add3EigenImpl { ComputeAdd3EigenImpl53 static void Compute(const Device& d, typename TTypes<T>::Flat out, 54 typename TTypes<T>::ConstFlat in1, 55 typename TTypes<T>::ConstFlat in2, 56 typename TTypes<T>::ConstFlat in3) { 57 out.device(d) = in1 + in2 + in3; 58 } 59 }; 60 61 template <typename Device, typename T> 62 struct Add4Functor { 63 void operator()(const Device& d, typename TTypes<T>::Flat out, 64 typename TTypes<T>::ConstFlat in1, 65 typename TTypes<T>::ConstFlat in2, 66 typename TTypes<T>::ConstFlat in3, 67 typename TTypes<T>::ConstFlat in4); 68 }; 69 70 template <typename Device, typename T> 71 struct Add4EigenImpl { ComputeAdd4EigenImpl72 static void Compute(const Device& d, typename TTypes<T>::Flat out, 73 typename TTypes<T>::ConstFlat in1, 74 typename TTypes<T>::ConstFlat in2, 75 typename TTypes<T>::ConstFlat in3, 76 typename TTypes<T>::ConstFlat in4) { 77 out.device(d) = in1 + in2 + in3 + in4; 78 } 79 }; 80 81 template <typename Device, typename T> 82 struct Add5Functor { 83 void operator()(const Device& d, typename TTypes<T>::Flat out, 84 typename TTypes<T>::ConstFlat in1, 85 typename TTypes<T>::ConstFlat in2, 86 typename TTypes<T>::ConstFlat in3, 87 typename TTypes<T>::ConstFlat in4, 88 typename TTypes<T>::ConstFlat in5); 89 }; 90 91 template <typename Device, typename T> 92 struct Add5EigenImpl { ComputeAdd5EigenImpl93 static void Compute(const Device& d, typename TTypes<T>::Flat out, 94 typename TTypes<T>::ConstFlat in1, 95 typename TTypes<T>::ConstFlat in2, 96 typename TTypes<T>::ConstFlat in3, 97 typename TTypes<T>::ConstFlat in4, 98 typename TTypes<T>::ConstFlat in5) { 99 out.device(d) = in1 + in2 + in3 + in4 + in5; 100 } 101 }; 102 103 template <typename Device, typename T> 104 struct Add6Functor { 105 void operator()(const Device& d, typename TTypes<T>::Flat out, 106 typename TTypes<T>::ConstFlat in1, 107 typename TTypes<T>::ConstFlat in2, 108 typename TTypes<T>::ConstFlat in3, 109 typename TTypes<T>::ConstFlat in4, 110 typename TTypes<T>::ConstFlat in5, 111 typename TTypes<T>::ConstFlat in6); 112 }; 113 114 template <typename Device, typename T> 115 struct Add6EigenImpl { ComputeAdd6EigenImpl116 static void Compute(const Device& d, typename TTypes<T>::Flat out, 117 typename TTypes<T>::ConstFlat in1, 118 typename TTypes<T>::ConstFlat in2, 119 typename TTypes<T>::ConstFlat in3, 120 typename TTypes<T>::ConstFlat in4, 121 typename TTypes<T>::ConstFlat in5, 122 typename TTypes<T>::ConstFlat in6) { 123 out.device(d) = in1 + in2 + in3 + in4 + in5 + in6; 124 } 125 }; 126 127 template <typename Device, typename T> 128 struct Add7Functor { 129 void operator()(const Device& d, typename TTypes<T>::Flat out, 130 typename TTypes<T>::ConstFlat in1, 131 typename TTypes<T>::ConstFlat in2, 132 typename TTypes<T>::ConstFlat in3, 133 typename TTypes<T>::ConstFlat in4, 134 typename TTypes<T>::ConstFlat in5, 135 typename TTypes<T>::ConstFlat in6, 136 typename TTypes<T>::ConstFlat in7); 137 }; 138 139 template <typename Device, typename T> 140 struct Add7EigenImpl { ComputeAdd7EigenImpl141 static void Compute(const Device& d, typename TTypes<T>::Flat out, 142 typename TTypes<T>::ConstFlat in1, 143 typename TTypes<T>::ConstFlat in2, 144 typename TTypes<T>::ConstFlat in3, 145 typename TTypes<T>::ConstFlat in4, 146 typename TTypes<T>::ConstFlat in5, 147 typename TTypes<T>::ConstFlat in6, 148 typename TTypes<T>::ConstFlat in7) { 149 out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7; 150 } 151 }; 152 153 template <typename Device, typename T> 154 struct Add8Functor { 155 void operator()( 156 const Device& d, typename TTypes<T>::Flat out, 157 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 158 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 159 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 160 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8); 161 }; 162 163 template <typename Device, typename T> 164 struct Add8EigenImpl { ComputeAdd8EigenImpl165 static void Compute( 166 const Device& d, typename TTypes<T>::Flat out, 167 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 168 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 169 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 170 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) { 171 out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8; 172 } 173 }; 174 175 // Add8p is like Add8 except the underlying implementation should += 176 // rather than assign to the output. 177 template <typename Device, typename T> 178 struct Add8pFunctor { 179 void operator()( 180 const Device& d, typename TTypes<T>::Flat out, 181 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 182 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 183 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 184 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8); 185 }; 186 187 template <typename Device, typename T> 188 struct Add8pEigenImpl { ComputeAdd8pEigenImpl189 static void Compute( 190 const Device& d, typename TTypes<T>::Flat out, 191 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 192 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 193 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 194 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) { 195 out.device(d) += in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8; 196 } 197 }; 198 199 template <typename Device, typename T> 200 struct Add9Functor { 201 void operator()( 202 const Device& d, typename TTypes<T>::Flat out, 203 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 204 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 205 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 206 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8, 207 typename TTypes<T>::ConstFlat in9); 208 }; 209 210 template <typename Device, typename T> 211 struct Add9EigenImpl { ComputeAdd9EigenImpl212 static void Compute( 213 const Device& d, typename TTypes<T>::Flat out, 214 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 215 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 216 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 217 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8, 218 typename TTypes<T>::ConstFlat in9) { 219 out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8 + in9; 220 } 221 }; 222 223 } // namespace functor 224 } // namespace tensorflow 225 226 #endif // TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_ 227