/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_
#define TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_

#include <type_traits>

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class OpKernelContext;
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

namespace scatter_op {

// The operation used to combine an existing value in `params` with an
// incoming `update`: overwrite, arithmetic accumulation, or elementwise
// min/max.
enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV, MIN, MAX };

namespace internal {

template <scatter_op::UpdateOp Op>
struct Assign {};
template <>
struct Assign<scatter_op::UpdateOp::ASSIGN> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p.setConstant(u);
  }
};
template <>
struct Assign<scatter_op::UpdateOp::ADD> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p += u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p + u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::SUB> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p -= u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p + static_cast<Update>(-u);
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MUL> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p *= u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p * u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::DIV> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p /= u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p / u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MIN> {
  // This method requires that Params and Update are tensor types.
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = p.cwiseMin(u);
  }
  // Same thing, but for Update being a scalar type.
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p.cwiseMin(u);
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MAX> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = p.cwiseMax(u);
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p.cwiseMax(u);
  }
};
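
// Illustrative sketch (comment only; `params`, `updates`, `index`, and `i`
// are hypothetical names): Assign<Op> is meant to be instantiated with Eigen
// expressions such as the row chips of a rank-2 tensor. Assuming `params` is
// a TTypes<float>::Matrix and `updates` a TTypes<float>::ConstMatrix with
// matching row width:
//
//   // Accumulate row `i` of `updates` into row `index` of `params`.
//   Assign<UpdateOp::ADD>::Run(params.chip<0>(index), updates.chip<0>(i));
//
//   // Broadcast one scalar across row `index` of `params`.
//   Assign<UpdateOp::ASSIGN>::RunScalar(params.chip<0>(index), 3.0f);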

#ifdef TENSORFLOW_USE_SYCL
template <scatter_op::UpdateOp Op>
struct AssignSYCL {};
template <>
struct AssignSYCL<scatter_op::UpdateOp::ASSIGN> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) = u;
  }
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::ADD> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) += u;
  }
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::SUB> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) -= u;
  }
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::MUL> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) = p * u;
  }
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::DIV> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) = p / u;
  }
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::MIN> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) = p.cwiseMin(u);
  }
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::MAX> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) = p.cwiseMax(u);
  }
};
#endif  // TENSORFLOW_USE_SYCL

}  // namespace internal
}  // namespace scatter_op

namespace functor {
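// Applies the update `op` to rows of `params` selected by `indices`, using
// the corresponding rows of `updates`. Returns -1 on success, or the position
// in `indices` of the first out-of-range index (scattering stops there).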
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices);
};
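
// Illustrative sketch of a kernel-side call; the names `ctx`, `params`,
// `updates`, and `indices` are hypothetical, and shapes are assumed to have
// been validated already (as in DoCompute()):
//
//   functor::ScatterFunctor<CPUDevice, float, int32,
//                           scatter_op::UpdateOp::ADD> scatter;
//   const int32 bad_i = scatter(ctx, ctx->eigen_device<CPUDevice>(),
//                               params.matrix<float>(),
//                               updates.matrix<float>(),
//                               indices.flat<int32>());
//   OP_REQUIRES(ctx, bad_i < 0,
//               errors::InvalidArgument("indices[", bad_i, "] out of range"));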

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctorBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
      scatter_op::internal::Assign<op>::Run(params.template chip<0>(index),
                                            updates.template chip<0>(i));
    }
    return -1;
  }
};

// Variant elements are not trivially copyable and have no Eigen kernels, so
// the ASSIGN specialization for Variant copies them element by element.
template <typename Device, typename Index>
struct ScatterFunctorVariantAssignBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<Variant>::Matrix params,
                   typename TTypes<Variant>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index cols = static_cast<Index>(params.dimension(1));
    DCHECK_EQ(N, updates.dimension(0));
    DCHECK_EQ(cols, updates.dimension(1));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
      for (Index j = 0; j < cols; ++j) {
        const Variant& to_scatter = updates(i, j);
        params(index, j) = to_scatter;
      }
    }
    return -1;
  }
};

template <typename Index>
struct ScatterFunctor<CPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN>
    : ScatterFunctorVariantAssignBase<CPUDevice, Index> {};

template <typename Index>
struct ScatterFunctor<GPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN>
    : ScatterFunctorVariantAssignBase<GPUDevice, Index> {};

#ifdef TENSORFLOW_USE_SYCL
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctorBase<SYCLDevice, T, Index, op> {
  Index operator()(OpKernelContext* c, const SYCLDevice& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
      scatter_op::internal::AssignSYCL<op>::Run(
          d, params.template chip<0>(index), updates.template chip<0>(i));
    }
    return -1;
  }
};
#endif  // TENSORFLOW_USE_SYCL

template <typename T, typename Index>
struct ScatterFunctorBase<CPUDevice, T, Index, scatter_op::UpdateOp::ASSIGN> {
  Index operator()(OpKernelContext* c, const CPUDevice& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    // Every type that reaches this specialization except string is trivially
    // copyable, so a whole row can be copied with one memmove; strings own
    // heap storage and must be assigned element-wise through Eigen below.
    if (!std::is_same<T, string>::value) {
      for (Index i = 0; i < N; i++) {
        // Grab the index and check its validity.  Do this carefully,
        // to avoid checking the value and grabbing it again from
        // memory a second time (a security risk since it may change in
        // between).
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) return i;
        memmove(params.data() + index * params.dimension(1),
                updates.data() + i * updates.dimension(1),
                updates.dimension(1) * sizeof(T));
      }
    } else {
      for (Index i = 0; i < N; i++) {
        // Grab the index and check its validity.  Do this carefully,
        // to avoid checking the value and grabbing it again from
        // memory a second time (a security risk since it may change in
        // between).
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) return i;
        // Copy last Ndim-1 dimensions of updates[i] to params[index]
        scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::Run(
            params.template chip<0>(index), updates.template chip<0>(i));
      }
    }
    return -1;
  }
};

template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor<CPUDevice, T, Index, op>
    : ScatterFunctorBase<CPUDevice, T, Index, op> {};

#ifdef TENSORFLOW_USE_SYCL
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctorSYCL {
  Index operator()(OpKernelContext* c, const SYCLDevice& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::Flat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
      scatter_op::internal::AssignSYCL<op>::Run(
          d, params.template chip<0>(index), updates.template chip<0>(i));
    }
    return -1;
  }
};
#endif  // TENSORFLOW_USE_SYCL

// Applies the update `op` between the single scalar `update` and each row of
// `params` selected by `indices`. Returns -1 on success, or the position in
// `indices` of the first out-of-range index.
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctor {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices);
};
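
// Illustrative sketch (hypothetical kernel-side names; shapes assumed to be
// validated beforehand):
//
//   functor::ScatterScalarFunctor<CPUDevice, float, int32,
//                                 scatter_op::UpdateOp::ASSIGN> scatter;
//   const int32 bad_i = scatter(ctx, ctx->eigen_device<CPUDevice>(),
//                               params.matrix<float>(),
//                               update.scalar<float>(),
//                               indices.flat<int32>());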

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctorBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      scatter_op::internal::Assign<op>::RunScalar(
          params.template chip<0>(index), update());
    }
    return -1;
  }
};

template <typename Device, typename Index>
struct ScatterScalarFunctorVariantAssignBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<Variant>::Matrix params,
                   const typename TTypes<Variant>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index cols = static_cast<Index>(params.dimension(1));
    const Variant& to_scatter = update();
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      for (Index j = 0; j < cols; ++j) {
        params(index, j) = to_scatter;
      }
    }
    return -1;
  }
};

template <typename Index>
struct ScatterScalarFunctor<CPUDevice, Variant, Index,
                            scatter_op::UpdateOp::ASSIGN>
    : ScatterScalarFunctorVariantAssignBase<CPUDevice, Index> {};
template <typename Index>
struct ScatterScalarFunctor<GPUDevice, Variant, Index,
                            scatter_op::UpdateOp::ASSIGN>
    : ScatterScalarFunctorVariantAssignBase<GPUDevice, Index> {};

#ifdef TENSORFLOW_USE_SYCL
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctorBase<SYCLDevice, T, Index, op> {
  Index operator()(OpKernelContext* c, const SYCLDevice& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      scatter_op::internal::AssignSYCL<op>::RunScalar(
          d, params.template chip<0>(index), update());
    }
    return -1;
  }
};
#endif  // TENSORFLOW_USE_SYCL

template <typename T, typename Index>
struct ScatterScalarFunctorBase<CPUDevice, T, Index,
                                scatter_op::UpdateOp::ASSIGN> {
  Index operator()(OpKernelContext* c, const CPUDevice& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::RunScalar(
          params.template chip<0>(index), update());
    }
    return -1;
  }
};

template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctor<CPUDevice, T, Index, op>
    : ScatterScalarFunctorBase<CPUDevice, T, Index, op> {};

#ifdef TENSORFLOW_USE_SYCL
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctorSYCL {
  Index operator()(OpKernelContext* c, const SYCLDevice& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::Flat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      scatter_op::internal::AssignSYCL<op>::Run(
          d, params.template chip<0>(index), update());
    }
    return -1;
  }
};
#endif  // TENSORFLOW_USE_SYCL

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_