/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_
#define TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_

#include <type_traits>

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class OpKernelContext;
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

namespace scatter_op {

enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV, MIN, MAX };

namespace internal {

template <scatter_op::UpdateOp Op>
struct Assign {};
template <>
struct Assign<scatter_op::UpdateOp::ASSIGN> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p.setConstant(u);
  }
};
template <>
struct Assign<scatter_op::UpdateOp::ADD> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p += u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p + u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::SUB> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p -= u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p + static_cast<Update>(-u);
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MUL> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p *= u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p * u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::DIV> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p /= u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p / u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MIN> {
  // This method requires that Params and Update are tensor types.
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = p.cwiseMin(u);
  }
  // Same thing, but for Update being a scalar type.
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p.cwiseMin(u);
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MAX> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = p.cwiseMax(u);
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p.cwiseMax(u);
  }
};
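// Illustrative sketch (not part of this header's API): Run and RunScalar are
// instantiated with Eigen expression types, typically row "chips" of a rank-2
// tensor, so the assignment writes through to the underlying buffer. Assuming
// hypothetical tensors `params` ([4, 2]) and `updates` ([1, 2]):
//
//   Eigen::Tensor<float, 2, Eigen::RowMajor> params(4, 2);
//   Eigen::Tensor<float, 2, Eigen::RowMajor> updates(1, 2);
//   params.setZero();
//   updates.setConstant(3.f);
//   // Accumulate updates row 0 into params row 2.
//   Assign<UpdateOp::ADD>::Run(params.chip<0>(2), updates.chip<0>(0));
//   // Broadcast a scalar minimum over params row 1.
//   Assign<UpdateOp::MIN>::RunScalar(params.chip<0>(1), 0.f);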

#ifdef TENSORFLOW_USE_SYCL
template <scatter_op::UpdateOp Op>
struct AssignSYCL {};
template <>
struct AssignSYCL<scatter_op::UpdateOp::ASSIGN> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) = u;
  }
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::ADD> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) += u;
  }
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::SUB> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) -= u;
  }
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::MUL> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) = p * u;
  }
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::DIV> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) = p / u;
  }
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::MIN> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) = p.cwiseMin(u);
  }
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::MAX> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) = p.cwiseMax(u);
  }
};
#endif  // TENSORFLOW_USE_SYCL

}  // namespace internal
}  // namespace scatter_op

namespace functor {
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices);
};

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctorBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully, to avoid
      // checking the value and grabbing it again from memory a second time
      // (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
      scatter_op::internal::Assign<op>::Run(params.template chip<0>(index),
                                            updates.template chip<0>(i));
    }
    return -1;
  }
};
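// Example call site, as it might appear in a kernel's Compute() (a minimal
// sketch; `context`, `params_tensor`, `updates_tensor`, and `indices_tensor`
// are hypothetical names, and shape validation is assumed to have already
// happened):
//
//   functor::ScatterFunctor<CPUDevice, float, int32,
//                           scatter_op::UpdateOp::ADD> scatter;
//   const int32 bad_i = scatter(context, context->eigen_device<CPUDevice>(),
//                               params_tensor.matrix<float>(),
//                               updates_tensor.matrix<float>(),
//                               indices_tensor.flat<int32>());
//   OP_REQUIRES(context, bad_i < 0,
//               errors::InvalidArgument("indices out of range"));
//
// The functor returns -1 on success, or the position within `indices` of the
// first out-of-range entry.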

template <typename Device, typename Index>
struct ScatterFunctorVariantAssignBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<Variant>::Matrix params,
                   typename TTypes<Variant>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index cols = static_cast<Index>(params.dimension(1));
    DCHECK_EQ(N, updates.dimension(0));
    DCHECK_EQ(cols, updates.dimension(1));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully, to avoid
      // checking the value and grabbing it again from memory a second time
      // (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
      for (int j = 0; j < cols; ++j) {
        const Variant& to_scatter = updates(i, j);
        params(index, j) = to_scatter;
      }
    }
    return -1;
  }
};

template <typename Index>
struct ScatterFunctor<CPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN>
    : ScatterFunctorVariantAssignBase<CPUDevice, Index> {};

template <typename Index>
struct ScatterFunctor<GPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN>
    : ScatterFunctorVariantAssignBase<GPUDevice, Index> {};

#ifdef TENSORFLOW_USE_SYCL
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctorBase<SYCLDevice, T, Index, op> {
  Index operator()(OpKernelContext* c, const SYCLDevice& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully, to avoid
      // checking the value and grabbing it again from memory a second time
      // (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
      scatter_op::internal::AssignSYCL<op>::Run(
          d, params.template chip<0>(index), updates.template chip<0>(i));
    }
    return -1;
  }
};
#endif  // TENSORFLOW_USE_SYCL

template <typename T, typename Index>
struct ScatterFunctorBase<CPUDevice, T, Index, scatter_op::UpdateOp::ASSIGN> {
  Index operator()(OpKernelContext* c, const CPUDevice& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    if (!std::is_same<T, string>::value) {
      for (Index i = 0; i < N; i++) {
        // Grab the index and check its validity.  Do this carefully, to
        // avoid checking the value and grabbing it again from memory a
        // second time (a security risk since it may change in between).
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) return i;
        memmove(params.data() + index * params.dimension(1),
                updates.data() + i * updates.dimension(1),
                updates.dimension(1) * sizeof(T));
      }
    } else {
      for (Index i = 0; i < N; i++) {
        // Grab the index and check its validity.  Do this carefully, to
        // avoid checking the value and grabbing it again from memory a
        // second time (a security risk since it may change in between).
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) return i;
        // Copy last Ndim-1 dimensions of updates[i] to params[index]
        scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::Run(
            params.template chip<0>(index), updates.template chip<0>(i));
      }
    }
    return -1;
  }
};
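// Note on the specialization above: for every element type other than
// `string`, the per-row assignment collapses to a single memmove, which is
// valid because rows of the row-major params matrix are contiguous (row
// `index` begins at flat offset index * params.dimension(1)). For `string`,
// where raw byte copies would be incorrect, it falls back to the element-wise
// Eigen assignment. The memmove branch is roughly equivalent to this
// hypothetical loop:
//
//   for (Index j = 0; j < updates.dimension(1); ++j) {
//     params(index, j) = updates(i, j);
//   }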

template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor<CPUDevice, T, Index, op>
    : ScatterFunctorBase<CPUDevice, T, Index, op> {};

#ifdef TENSORFLOW_USE_SYCL
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctorSYCL {
  Index operator()(OpKernelContext* c, const SYCLDevice& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::Flat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
      scatter_op::internal::AssignSYCL<op>::Run(
          d, params.template chip<0>(index), updates.template chip<0>(i));
    }
    return -1;
  }
};
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctor {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices);
};

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctorBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully, to avoid
      // checking the value and grabbing it again from memory a second time
      // (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      scatter_op::internal::Assign<op>::RunScalar(
          params.template chip<0>(index), update());
    }
    return -1;
  }
};
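// Example call site (a minimal sketch; `context`, `params_tensor`,
// `update_tensor`, and `indices_tensor` are hypothetical names). The single
// scalar update is broadcast across every row named in `indices`:
//
//   functor::ScatterScalarFunctor<CPUDevice, float, int32,
//                                 scatter_op::UpdateOp::MUL> scatter;
//   const int32 bad_i = scatter(context, context->eigen_device<CPUDevice>(),
//                               params_tensor.matrix<float>(),
//                               update_tensor.scalar<float>(),
//                               indices_tensor.flat<int32>());
//   // As with ScatterFunctor, bad_i is -1 on success or the offset of the
//   // first out-of-range index.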

template <typename Device, typename Index>
struct ScatterScalarFunctorVariantAssignBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<Variant>::Matrix params,
                   const typename TTypes<Variant>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index cols = static_cast<Index>(params.dimension(1));
    const Variant& to_scatter = update();
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully, to avoid
      // checking the value and grabbing it again from memory a second time
      // (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      for (Index j = 0; j < cols; ++j) {
        params(index, j) = to_scatter;
      }
    }
    return -1;
  }
};

template <typename Index>
struct ScatterScalarFunctor<CPUDevice, Variant, Index,
                            scatter_op::UpdateOp::ASSIGN>
    : ScatterScalarFunctorVariantAssignBase<CPUDevice, Index> {};
template <typename Index>
struct ScatterScalarFunctor<GPUDevice, Variant, Index,
                            scatter_op::UpdateOp::ASSIGN>
    : ScatterScalarFunctorVariantAssignBase<GPUDevice, Index> {};

#ifdef TENSORFLOW_USE_SYCL
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctorBase<SYCLDevice, T, Index, op> {
  Index operator()(OpKernelContext* c, const SYCLDevice& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully, to avoid
      // checking the value and grabbing it again from memory a second time
      // (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      scatter_op::internal::AssignSYCL<op>::RunScalar(
          d, params.template chip<0>(index), update);
    }
    return -1;
  }
};
#endif  // TENSORFLOW_USE_SYCL

template <typename T, typename Index>
struct ScatterScalarFunctorBase<CPUDevice, T, Index,
                                scatter_op::UpdateOp::ASSIGN> {
  Index operator()(OpKernelContext* c, const CPUDevice& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully, to avoid
      // checking the value and grabbing it again from memory a second time
      // (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::RunScalar(
          params.template chip<0>(index), update());
    }
    return -1;
  }
};
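// Behavior with duplicate indices follows from the sequential loops in the
// CPU implementations above: updates are applied in order, so for ASSIGN the
// last update to a repeated index wins, while for accumulating ops such as
// ADD every occurrence contributes. E.g., with indices = {0, 0} and ADD,
// params row 0 receives both update rows.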

template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctor<CPUDevice, T, Index, op>
    : ScatterScalarFunctorBase<CPUDevice, T, Index, op> {};

#ifdef TENSORFLOW_USE_SYCL
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctorSYCL {
  Index operator()(OpKernelContext* c, const SYCLDevice& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::Flat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      scatter_op::internal::AssignSYCL<op>::Run(
          d, params.template chip<0>(index), update());
    }
    return -1;
  }
};
#endif  // TENSORFLOW_USE_SYCL

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_