1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
17 #define TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
18 
19 // Functor definitions for Aggregate ops, must be compilable by nvcc.
20 
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/core/framework/tensor_types.h"
23 
24 namespace tensorflow {
25 namespace functor {
26 
27 template <typename Device, typename T>
28 struct Add2Functor {
29   void operator()(const Device& d, typename TTypes<T>::Flat out,
30                   typename TTypes<T>::ConstFlat in1,
31                   typename TTypes<T>::ConstFlat in2);
32 };
33 
34 template <typename Device, typename T>
35 struct Add2EigenImpl {
ComputeAdd2EigenImpl36   static void Compute(const Device& d, typename TTypes<T>::Flat out,
37                       typename TTypes<T>::ConstFlat in1,
38                       typename TTypes<T>::ConstFlat in2) {
39     out.device(d) = in1 + in2;
40   }
41 };
42 
43 template <typename Device, typename T>
44 struct Add3Functor {
45   void operator()(const Device& d, typename TTypes<T>::Flat out,
46                   typename TTypes<T>::ConstFlat in1,
47                   typename TTypes<T>::ConstFlat in2,
48                   typename TTypes<T>::ConstFlat in3);
49 };
50 
51 template <typename Device, typename T>
52 struct Add3EigenImpl {
ComputeAdd3EigenImpl53   static void Compute(const Device& d, typename TTypes<T>::Flat out,
54                       typename TTypes<T>::ConstFlat in1,
55                       typename TTypes<T>::ConstFlat in2,
56                       typename TTypes<T>::ConstFlat in3) {
57     out.device(d) = in1 + in2 + in3;
58   }
59 };
60 
61 template <typename Device, typename T>
62 struct Add4Functor {
63   void operator()(const Device& d, typename TTypes<T>::Flat out,
64                   typename TTypes<T>::ConstFlat in1,
65                   typename TTypes<T>::ConstFlat in2,
66                   typename TTypes<T>::ConstFlat in3,
67                   typename TTypes<T>::ConstFlat in4);
68 };
69 
70 template <typename Device, typename T>
71 struct Add4EigenImpl {
ComputeAdd4EigenImpl72   static void Compute(const Device& d, typename TTypes<T>::Flat out,
73                       typename TTypes<T>::ConstFlat in1,
74                       typename TTypes<T>::ConstFlat in2,
75                       typename TTypes<T>::ConstFlat in3,
76                       typename TTypes<T>::ConstFlat in4) {
77     out.device(d) = in1 + in2 + in3 + in4;
78   }
79 };
80 
81 template <typename Device, typename T>
82 struct Add5Functor {
83   void operator()(const Device& d, typename TTypes<T>::Flat out,
84                   typename TTypes<T>::ConstFlat in1,
85                   typename TTypes<T>::ConstFlat in2,
86                   typename TTypes<T>::ConstFlat in3,
87                   typename TTypes<T>::ConstFlat in4,
88                   typename TTypes<T>::ConstFlat in5);
89 };
90 
91 template <typename Device, typename T>
92 struct Add5EigenImpl {
ComputeAdd5EigenImpl93   static void Compute(const Device& d, typename TTypes<T>::Flat out,
94                       typename TTypes<T>::ConstFlat in1,
95                       typename TTypes<T>::ConstFlat in2,
96                       typename TTypes<T>::ConstFlat in3,
97                       typename TTypes<T>::ConstFlat in4,
98                       typename TTypes<T>::ConstFlat in5) {
99     out.device(d) = in1 + in2 + in3 + in4 + in5;
100   }
101 };
102 
103 template <typename Device, typename T>
104 struct Add6Functor {
105   void operator()(const Device& d, typename TTypes<T>::Flat out,
106                   typename TTypes<T>::ConstFlat in1,
107                   typename TTypes<T>::ConstFlat in2,
108                   typename TTypes<T>::ConstFlat in3,
109                   typename TTypes<T>::ConstFlat in4,
110                   typename TTypes<T>::ConstFlat in5,
111                   typename TTypes<T>::ConstFlat in6);
112 };
113 
114 template <typename Device, typename T>
115 struct Add6EigenImpl {
ComputeAdd6EigenImpl116   static void Compute(const Device& d, typename TTypes<T>::Flat out,
117                       typename TTypes<T>::ConstFlat in1,
118                       typename TTypes<T>::ConstFlat in2,
119                       typename TTypes<T>::ConstFlat in3,
120                       typename TTypes<T>::ConstFlat in4,
121                       typename TTypes<T>::ConstFlat in5,
122                       typename TTypes<T>::ConstFlat in6) {
123     out.device(d) = in1 + in2 + in3 + in4 + in5 + in6;
124   }
125 };
126 
127 template <typename Device, typename T>
128 struct Add7Functor {
129   void operator()(const Device& d, typename TTypes<T>::Flat out,
130                   typename TTypes<T>::ConstFlat in1,
131                   typename TTypes<T>::ConstFlat in2,
132                   typename TTypes<T>::ConstFlat in3,
133                   typename TTypes<T>::ConstFlat in4,
134                   typename TTypes<T>::ConstFlat in5,
135                   typename TTypes<T>::ConstFlat in6,
136                   typename TTypes<T>::ConstFlat in7);
137 };
138 
139 template <typename Device, typename T>
140 struct Add7EigenImpl {
ComputeAdd7EigenImpl141   static void Compute(const Device& d, typename TTypes<T>::Flat out,
142                       typename TTypes<T>::ConstFlat in1,
143                       typename TTypes<T>::ConstFlat in2,
144                       typename TTypes<T>::ConstFlat in3,
145                       typename TTypes<T>::ConstFlat in4,
146                       typename TTypes<T>::ConstFlat in5,
147                       typename TTypes<T>::ConstFlat in6,
148                       typename TTypes<T>::ConstFlat in7) {
149     out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7;
150   }
151 };
152 
153 template <typename Device, typename T>
154 struct Add8Functor {
155   void operator()(
156       const Device& d, typename TTypes<T>::Flat out,
157       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
158       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
159       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
160       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8);
161 };
162 
163 template <typename Device, typename T>
164 struct Add8EigenImpl {
ComputeAdd8EigenImpl165   static void Compute(
166       const Device& d, typename TTypes<T>::Flat out,
167       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
168       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
169       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
170       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
171     out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8;
172   }
173 };
174 
175 // Add8p is like Add8 except the underlying implementation should +=
176 // rather than assign to the output.
177 template <typename Device, typename T>
178 struct Add8pFunctor {
179   void operator()(
180       const Device& d, typename TTypes<T>::Flat out,
181       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
182       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
183       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
184       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8);
185 };
186 
187 template <typename Device, typename T>
188 struct Add8pEigenImpl {
ComputeAdd8pEigenImpl189   static void Compute(
190       const Device& d, typename TTypes<T>::Flat out,
191       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
192       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
193       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
194       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
195     out.device(d) += in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8;
196   }
197 };
198 
199 template <typename Device, typename T>
200 struct Add9Functor {
201   void operator()(
202       const Device& d, typename TTypes<T>::Flat out,
203       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
204       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
205       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
206       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
207       typename TTypes<T>::ConstFlat in9);
208 };
209 
210 template <typename Device, typename T>
211 struct Add9EigenImpl {
ComputeAdd9EigenImpl212   static void Compute(
213       const Device& d, typename TTypes<T>::Flat out,
214       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
215       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
216       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
217       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
218       typename TTypes<T>::ConstFlat in9) {
219     out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8 + in9;
220   }
221 };
222 
223 }  // namespace functor
224 }  // namespace tensorflow
225 
226 #endif  // TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
227