/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Our general strategy for preventing conflicts between concurrent
// reads and writes of resource variables is to:
// * For read operations, we:
//   - acquire the variable's mutex (in "shared" mode);
//   - make a (shallow) copy of the Tensor object, which increments
//     the reference count on the variable's TensorBuffer;
//   - release the variable's mutex;
//   - use the copy of the Tensor object to do the read.
// * For write operations, we:
//   - acquire the variable's mutex (in "exclusive" mode);
//   - check the reference count of the variable's TensorBuffer and,
//     if it is >1, make a deep copy of the variable's Tensor;
//   - mutate the variable's Tensor;
//   - and release the variable's mutex.
// This allows several read operations to all use the same
// TensorBuffer without needing to copy. When it comes time to write,
// a copy is made only if there is an outstanding read using the
// buffer. Write operations are serialized by the variable's mutex.
//
// For sparse operations (scatter, gather, sparse optimizer updates),
// we need to avoid copies, since there may not be enough memory for
// two copies of the whole tensor. To support this, we make two
// modifications to the above strategy:
// * For sparse reads (gather), we hold the variable's mutex (still in
//   "shared" mode) for the duration of the whole read. This means
//   that as long as you only do sparse read operations, no write will
//   see a reference count >1.
// * For sparse write operations where the user explicitly specifies
//   that they want to perform the write without locks held
//   (use_locking=false), we never copy, even if the variable's
//   reference count is >1.
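//
// A rough illustrative sketch of the two dense paths described above
// (hypothetical helper code, not part of this file):
//
//   // Read: snapshot the Tensor under a shared lock, then read outside it.
//   Tensor ReadSnapshot(Var* var) {
//     tf_shared_lock l(*var->mu());
//     return *var->tensor();  // shallow copy; bumps the TensorBuffer refcount
//   }
//
//   // In-place update (e.g. AssignAdd): deep-copy first only if some reader
//   // still holds the buffer, then mutate.
//   void AddInPlace(Var* var, const Tensor& delta) {
//     mutex_lock l(*var->mu());
//     if (!var->tensor()->RefCountIsOne()) {
//       // copy-on-write: move the data into a freshly allocated buffer first
//     }
//     // ... mutate *var->tensor() in place ...
//   }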

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

#include "tensorflow/core/kernels/resource_variable_ops.h"

#include <memory>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/gather_functor.h"
#include "tensorflow/core/kernels/gather_nd_op.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp").Device(DEVICE_CPU),
                        ResourceHandlesOp<Var>);

ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
  OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
}

namespace {

Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
  Tensor* output;
  Notification n;
  Status status;
  AllocatorAttributes attr;
  if (t->dtype() == DT_VARIANT) {
    attr.set_on_host(true);
  }
  TF_RETURN_IF_ERROR(
      ctx->allocate_output(output_idx, t->shape(), &output, attr));
  if (t->dtype() == DT_VARIANT) {
    output->flat<Variant>() = t->flat<Variant>();
  } else if (ctx->op_device_context() != nullptr) {
    // TODO(apassos): remove the down_cast by just returning Device* from
    // OpKernelContext
    Device* device = down_cast<Device*>(ctx->device());
    ctx->op_device_context()->CopyTensorInSameDevice(
        t, device, output, [&n, &status](const Status& s) {
          status = s;
          n.Notify();
        });
    n.WaitForNotification();
    return status;
  } else {
    switch (t->dtype()) {
#define HANDLER(type)                       \
  case DataTypeToEnum<type>::value:         \
    output->flat<type>() = t->flat<type>(); \
    break;
      TF_CALL_ALL_TYPES(HANDLER);
#undef HANDLER
      default:
        return errors::Internal("Unsupported dtype", t->dtype());
    }
  }
  return Status::OK();
}

}  // namespace

void ReadVariableOp::Compute(OpKernelContext* ctx) {
  core::RefCountPtr<Var> variable;
  const ResourceHandle& handle = HandleFromInput(ctx, 0);
  const auto status = LookupResource(ctx, handle, &variable);
  OP_REQUIRES(ctx, status.ok(),
              errors::FailedPrecondition(
                  "Could not find variable ", handle.name(), ". ",
                  "This could mean that the variable has been deleted. ",
                  "In TF1, it can also mean the variable is uninitialized. ",
                  "Debug info: container=", handle.container(),
                  ", status=", status.ToString()));

  tf_shared_lock ml(*variable->mu());
  // We're acquiring a reference to the underlying buffer while
  // holding a shared lock to guarantee ordering of reads and
  // writes when in copy-on-write mode.
  const Tensor* t = variable->tensor();
  if (!variable->copy_on_read_mode.load()) {
    OP_REQUIRES(
        ctx, dtype_ == t->dtype(),
        errors::InvalidArgument(
            "Trying to read variable with wrong dtype. Expected ",
            DataTypeString(dtype_), " got ", DataTypeString(t->dtype())));
    ctx->set_output(0, *t);
  } else {
    OP_REQUIRES_OK(ctx, CopyVariable(0, ctx, t));
  }
}

ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
  int n;
  OP_REQUIRES_OK(c, c->GetAttr("N", &n));
  OP_REQUIRES_OK(c, c->GetAttr("dtypes", &dtypes_));
  OP_REQUIRES(c, n == dtypes_.size(),
              errors::InvalidArgument(
                  "Mismatched number of arguments to ReadVariablesOp (", n,
                  " vs. ", dtypes_.size(), ")"));
}

void ReadVariablesOp::Compute(OpKernelContext* ctx) {
  std::vector<core::RefCountPtr<Var>> variables(dtypes_.size());
  std::vector<const ResourceHandle*> handles(dtypes_.size());
  for (size_t i = 0; i < dtypes_.size(); ++i) {
    handles[i] = &HandleFromInput(ctx, i);
  }

  OP_REQUIRES_OK(ctx, LookupResources(ctx, handles, &variables));

  std::vector<string> uninitialized_vars;
  for (int64 i = 0; i < variables.size(); i++) {
    if (variables[i] == nullptr) {
      uninitialized_vars.push_back(handles[i]->name());
    }
  }

  OP_REQUIRES(ctx, uninitialized_vars.empty(),
              errors::FailedPrecondition(
                  "In ReadVariablesOp the following variables were "
                  "found uninitialized: ",
                  absl::StrJoin(uninitialized_vars, ", ")));

  for (size_t i = 0; i < dtypes_.size(); ++i) {
    // We're acquiring a reference to the underlying buffer while
    // holding a shared lock to guarantee ordering of reads and
    // writes.
    tf_shared_lock ml(*variables[i]->mu());
    OP_REQUIRES(ctx, dtypes_[i] == variables[i]->tensor()->dtype(),
                errors::InvalidArgument(
                    "Trying to read variable ", handles[i]->name(),
                    " from Container: ", handles[i]->container(),
                    " with wrong dtype. Expected ", DataTypeString(dtypes_[i]),
                    " got ", DataTypeString(variables[i]->tensor()->dtype())));
    if (variables[i]->copy_on_read_mode.load()) {
      OP_REQUIRES_OK(ctx, CopyVariable(i, ctx, variables[i]->tensor()));
    } else {
      const Tensor& t = *variables[i]->tensor();
      ctx->set_output(i, t);
    }
  }
}

REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
                        ReadVariableOp);
REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp").Device(DEVICE_CPU),
                        ReadVariablesOp);

VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
  OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
  OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));

  OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_and_shape_.dtype));
  OP_REQUIRES_OK(context, context->GetAttr("shape", &dtype_and_shape_.shape));

  is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME;

  if (!is_anonymous_) {
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(context, context->allocate_temp(DT_RESOURCE, TensorShape({}),
                                                   &resource_, attr));
    resource_.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
        context, container_, name_,
        std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
  }
}

void VarHandleOp::Compute(OpKernelContext* ctx) {
  if (is_anonymous_) {
    AllocatorAttributes attr;
    attr.set_on_host(true);
    Tensor handle;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
    handle.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
        ctx, container_, name_,
        std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
        ctx->stack_trace());
    ctx->set_output(0, handle);
  } else {
    ctx->set_output(0, resource_);
  }
}
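
// Illustrative (hypothetical pseudocode) consequence of the branch above: with
// a concrete shared_name every execution returns the handle precomputed in the
// constructor, so all runs address the same Var in the resource manager, e.g.
//
//   h1 = VarHandleOp(shared_name="v");  h2 = VarHandleOp(shared_name="v");
//   // h1 and h2 refer to the same variable.
//
// With shared_name == ResourceHandle::ANONYMOUS_NAME, each Compute() call
// mints a distinct handle, i.e. a fresh variable.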

REGISTER_KERNEL_BUILDER(Name("VarHandleOp").Device(DEVICE_CPU), VarHandleOp);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(
    Name("ReadVariableOp").Device(DEVICE_GPU).HostMemory("resource"),
    ReadVariableOp);
REGISTER_KERNEL_BUILDER(
    Name("_ReadVariablesOp").Device(DEVICE_GPU).HostMemory("resources"),
    ReadVariablesOp);

#define REGISTER_GPU_KERNELS(type)                             \
  namespace functor {                                          \
  template <>                                                  \
  void DenseUpdate<GPUDevice, type, ASSIGN>::operator()(       \
      const GPUDevice& d, typename TTypes<type>::Flat lhs,     \
      typename TTypes<type>::ConstFlat rhs);                   \
  extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
  }                                                            \
  REGISTER_KERNEL_BUILDER(Name("VarHandleOp")                  \
                              .Device(DEVICE_GPU)              \
                              .HostMemory("resource")          \
                              .TypeConstraint<type>("dtype"),  \
                          VarHandleOp)
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS

REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
                            .Device(DEVICE_GPU)
                            .HostMemory("resources")
                            .TypeConstraint("dtypes",
                                            {DT_INT64, DT_COMPLEX64,
                                             DT_COMPLEX128, DT_HALF, DT_FLOAT,
                                             DT_DOUBLE, DT_BOOL, DT_VARIANT}),
                        ResourceHandlesOp<Var>);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int32>("out_type"),
    VariableShapeOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int64>("out_type"),
    VariableShapeOp<int64>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("VariableShape")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("out_type")
                            .HostMemory("output")
                            .HostMemory("input"),
                        VariableShapeOp<int32>);
REGISTER_KERNEL_BUILDER(Name("VariableShape")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int64>("out_type")
                            .HostMemory("output")
                            .HostMemory("input"),
                        VariableShapeOp<int64>);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

DestroyResourceOp::DestroyResourceOp(OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr("ignore_lookup_error", &ignore_lookup_error_));
}

void DestroyResourceOp::Compute(OpKernelContext* ctx) {
  const ResourceHandle& p = HandleFromInput(ctx, 0);
  Status status = DeleteResource(ctx, p);
  if (ignore_lookup_error_ && errors::IsNotFound(status)) {
    return;
  }
  OP_REQUIRES_OK(ctx, status);
}

REGISTER_KERNEL_BUILDER(Name("DestroyResourceOp").Device(DEVICE_CPU),
                        DestroyResourceOp);
REGISTER_KERNEL_BUILDER(
    Name("DestroyResourceOp").Device(DEVICE_GPU).HostMemory("resource"),
    DestroyResourceOp);

template <typename Device, typename T>
class AssignVariableOp : public OpKernel {
 public:
  explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
    if (!c->GetAttr("_grappler_relax_allocator_constraints",
                    &relax_constraints_)
             .ok()) {
      relax_constraints_ = false;
    }
  }

  void Compute(OpKernelContext* context) override {
    OP_REQUIRES(context, dtype_ == context->input(1).dtype(),
                errors::InvalidArgument(
                    "Variable and value dtypes don't match; respectively, ",
                    DataTypeString(dtype_), " and ",
                    DataTypeString(context->input(1).dtype())));
    core::RefCountPtr<Var> variable;
    const Tensor& value = context->input(1);
    // Note: every op that mutates a resource variable assumes copy-on-write
    // semantics and copies the variable's Tensor if its refcount is greater
    // than 1 when it is about to modify it. This means AssignVariableOp never
    // needs to copy the incoming tensor: even if there are other live users of
    // it, none of them can modify it, so aliasing it here is always safe. This
    // holds even in esoteric cases where the same tensor initializes multiple
    // variables, or the tensor is a constant, since any future write triggers
    // a copy first.
    OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
                                context, HandleFromInput(context, 0), &variable,
                                [this, &value](Var** ptr) {
                                  *ptr = new Var(dtype_);
                                  *(*ptr)->tensor() = value;
                                  (*ptr)->is_initialized = true;
                                  return Status::OK();
                                }));
    mutex_lock ml(*variable->mu());
    // The (variable->tensor()->dtype() == DT_INVALID &&
    // !variable->is_initialized) check below allows an XLA-specific situation
    // in which this AssignVariableOp runs before the variable has ever been
    // initialized, so its dtype is still DT_INVALID. With TF-XLA this can
    // happen when execution takes the 'fallback' path, which essentially
    // invokes TensorFlow ops via partitioned_call.
    OP_REQUIRES(context,
                (variable->tensor()->dtype() == DT_INVALID &&
                 !variable->is_initialized) ||
                    variable->tensor()->dtype() == dtype_,
                errors::InvalidArgument(
                    "Trying to assign variable with wrong dtype. Expected ",
                    DataTypeString(variable->tensor()->dtype()), " got ",
                    DataTypeString(dtype_)));
    if (variable->copy_on_read_mode.load()) {
      PersistentTensor unused;
      Tensor* tmp;
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      attr.set_nic_compatible(true);
      OP_REQUIRES_OK(context,
                     context->allocate_persistent(value.dtype(), value.shape(),
                                                  &unused, &tmp, attr));
      functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
      copy_functor(context->eigen_device<Device>(), tmp->flat<T>(),
                   value.flat<T>());
      *variable->tensor() = *tmp;
    } else {
      *variable->tensor() = value;
    }
    variable->is_initialized = true;
  }

 private:
  DataType dtype_;
  bool relax_constraints_;
};
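
// Illustrative sketch (hypothetical graph pseudocode) of the copy-on-write
// interplay described in the comment inside Compute() above:
//
//   AssignVariableOp(v1, c);         // v1's tensor aliases c's buffer
//   AssignVariableOp(v2, c);         // v2 aliases the same buffer
//   AssignAddVariableOp(v1, delta);  // refcount > 1, so the update deep-copies
//                                    // first; v2 and c are unaffected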

template <typename Device>
class AssignVariableOp<Device, Variant> : public OpKernel {
 public:
  explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
    OP_REQUIRES(c, dtype_ == DT_VARIANT,
                errors::Internal("Variant kernel called with dtype: ",
                                 DataTypeString(dtype_)));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& value = context->input(1);
    core::RefCountPtr<Var> variable;
    OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
                                context, HandleFromInput(context, 0), &variable,
                                [](Var** ptr) {
                                  // Created on host.
                                  *ptr = new Var(DT_VARIANT);
                                  return Status::OK();
                                }));

    // For purposes of forwarding DT_VARIANT, we want the least
    // restrictive attr; we already know the input is on host.
    AllocatorAttributes attr;

    // Copying is unnecessary if we are the last user of the value
    // tensor: we can simply adopt the input tensor's buffer instead.
    // Note that Variant objects themselves always reside on host.
    //
    // We nevertheless want to signal to the runtime that the tensor
    // should reside in memory of the associated device, as Variant
    // tensors may be marked as sitting on either CPU or GPU.  This
    // helps to elide one or more copies.
    std::unique_ptr<Tensor> input_alias = context->forward_input(
        1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT,
        value.shape(),
        DEVICE_MEMORY /* HOST_MEMORY is only reserved for special cases */,
        attr);

    mutex_lock ml(*variable->mu());
    OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
                errors::InvalidArgument(
                    "Trying to assign variable with wrong dtype. Expected ",
                    DataTypeString(variable->tensor()->dtype()), " got ",
                    DataTypeString(DT_VARIANT)));
    variable->is_initialized = true;
    *variable->tensor() = Tensor(DT_VARIANT, value.shape());

    if (input_alias) {
      *variable->tensor() = *input_alias;
      return;
    }

    // Need to copy, but maybe we can re-use variable's buffer?
    if (!variable->tensor()->RefCountIsOne() ||
        !variable->tensor()->shape().IsSameSize(value.shape())) {
      PersistentTensor unused;
      Tensor* tmp;
      // Allocation of DT_VARIANT is always on host.
      attr.set_on_host(true);
      OP_REQUIRES_OK(context,
                     context->allocate_persistent(DT_VARIANT, value.shape(),
                                                  &unused, &tmp, attr));
      *variable->tensor() = *tmp;
    }

    const auto elements_in = value.flat<Variant>();
    auto elements_out = variable->tensor()->flat<Variant>();
    for (int64 i = 0; i < elements_in.size(); ++i) {
      elements_out(i) = elements_in(i);
    }
  }

 private:
  DataType dtype_;
};

#define REGISTER_KERNELS(type)                                \
  REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")            \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<type>("dtype"), \
                          AssignVariableOp<Eigen::ThreadPoolDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type)                           \
  REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")           \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<type>("dtype") \
                              .HostMemory("resource"),       \
                          AssignVariableOp<GPUDevice, type>);

TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T, DenseUpdateType Op>
class AssignUpdateVariableOp : public OpKernel {
 public:
  explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* context) override {
    core::RefCountPtr<Var> variable;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &variable));

    const Tensor& value = context->input(1);
    // TODO(apassos): We could possibly avoid the copy done by
    // PrepareToUpdateVariable() for commutative operations like Op ==
    // ADD if value's refcount was 1.
    mutex_lock ml(*variable->mu());
    Tensor* var_tensor = variable->tensor();
    OP_REQUIRES(context, var_tensor->shape().IsSameSize(value.shape()),
                errors::InvalidArgument("Cannot update variable with shape ",
                                        var_tensor->shape().DebugString(),
                                        " using a Tensor with shape ",
                                        value.shape().DebugString(),
                                        ", shapes must be equal."));
    OP_REQUIRES_OK(
        context, PrepareToUpdateVariable<Device, T>(
                     context, var_tensor, variable->copy_on_read_mode.load()));
    functor::DenseUpdate<Device, T, Op> update_functor;
    update_functor(context->eigen_device<Device>(), var_tensor->flat<T>(),
                   value.flat<T>());
  }
};

#define REGISTER_KERNELS(type)                                     \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AssignAddVariableOp")                                  \
          .Device(DEVICE_CPU)                                      \
          .TypeConstraint<type>("dtype"),                          \
      AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, ADD>); \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AssignSubVariableOp")                                  \
          .Device(DEVICE_CPU)                                      \
          .TypeConstraint<type>("dtype"),                          \
      AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, SUB>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp")                    \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("resource")                    \
                              .TypeConstraint<type>("dtype"),            \
                          AssignUpdateVariableOp<GPUDevice, type, ADD>); \
  REGISTER_KERNEL_BUILDER(Name("AssignSubVariableOp")                    \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("resource")                    \
                              .TypeConstraint<type>("dtype"),            \
                          AssignUpdateVariableOp<GPUDevice, type, SUB>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

class VarIsInitializedOp : public OpKernel {
 public:
  explicit VarIsInitializedOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* context) override {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}), &output));
    auto output_tensor = output->tensor<bool, 0>();
    core::RefCountPtr<Var> variable;
    Status s = LookupResource(context, HandleFromInput(context, 0), &variable);
    if (!s.ok()) {
      output_tensor() = false;
      return;
    }
    mutex_lock ml(*variable->mu());
    output_tensor() = variable->is_initialized;
  }
};

REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp").Device(DEVICE_CPU),
                        VarIsInitializedOp);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("is_initialized"),
                        IsResourceInitialized<Var>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T, typename Index>
class ResourceGatherOp : public OpKernel {
 public:
  explicit ResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("batch_dims", &batch_dims_));
  }

  void Compute(OpKernelContext* c) override {
    core::RefCountPtr<Var> v;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
    // NOTE: We hold the lock for the whole gather operation instead
    // of increasing the reference count of v->tensor() to avoid a
    // situation where a write to the same variable will see a
    // reference count greater than one and make a copy of the
    // (potentially very large) tensor buffer.
    tf_shared_lock ml(*v->mu());
    const Tensor& params = *v->tensor();
    const Tensor& indices = c->input(1);
    OP_REQUIRES(
        c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
        errors::InvalidArgument("params must be at least 1 dimensional"));

    // Check that we have enough index space
    const int64 N = indices.NumElements();
    OP_REQUIRES(
        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params.dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    // The result shape is params.shape[:batch_dims] +
    // indices.shape[batch_dims:] + params.shape[batch_dims+1:].
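    // Illustrative example (hypothetical shapes): with batch_dims_ = 1,
    // params.shape = [2, 4, 8] and indices.shape = [2, 3], the loops below
    // build result_shape = [2, 3, 8].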
    TensorShape result_shape;
    for (int i = 0; i < batch_dims_; ++i) {
      result_shape.AddDim(params.dim_size(i));
    }
    for (int i = batch_dims_; i < indices.dims(); ++i) {
      result_shape.AddDim(indices.dim_size(i));
    }
    for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
      result_shape.AddDim(params.dim_size(i));
    }

    Tensor* out = nullptr;
    Tensor tmp;
    if (params.dtype() == DT_VARIANT) {
      tmp = Tensor(DT_VARIANT, result_shape);
      c->set_output(0, tmp);
      out = &tmp;
    } else {
      OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
    }

    if (N > 0) {
      Tensor tmp_indices;

      // Points to the original or updated (if batch_dims is set) indices.
      const Tensor* op_indices = &indices;
      if (batch_dims_ > 0) {
        OP_REQUIRES_OK(c, c->allocate_temp(indices.dtype(), indices.shape(),
                                           &tmp_indices));
        functor::DenseUpdate<Device, Index, ASSIGN> copy_functor;
        copy_functor(c->eigen_device<Device>(), tmp_indices.flat<Index>(),
                     indices.flat<Index>());

        AddBatchOffsets(&tmp_indices, params);
        op_indices = &tmp_indices;
      }

      int64 gather_dim_size = 1;
      for (int idx = 0; idx <= batch_dims_; ++idx) {
        gather_dim_size *= params.dim_size(idx);
      }
      int64 inner_size = 1;
      for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
        inner_size *= params.dim_size(i);
      }
      auto params_flat = params.shaped<T, 3>({1, gather_dim_size, inner_size});
      const auto indices_flat = op_indices->flat<Index>();
      auto out_flat = out->shaped<T, 3>({1, N, out->NumElements() / N});

      functor::GatherFunctor<Device, T, Index> functor;
      int64 bad_i = functor(c, params_flat, indices_flat, out_flat);

      OP_REQUIRES(
          c, bad_i < 0,
          errors::InvalidArgument(
              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
              indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
    }
  }

 private:
  // Add the batch offset derived from params to each batch of indices.
  // Example: batch_dims = 1, indices = [[0, 1, 2], [0, 1, 2]]
  // If indexing into a params dimension of size 4, then the indices will become
  // [0, 1, 2, 4, 5, 6]
  void AddBatchOffsets(Tensor* indices, const Tensor& params) {
    int64 batch_size = 1;  // The size of all batch dimensions.
    for (int idx = 0; idx < batch_dims_; ++idx) {
      batch_size *= params.dim_size(idx);
    }

    auto indices_flat = indices->flat<Index>();
    int64 const index_inner_size = indices->NumElements() / batch_size;
    int64 const batch_offset = params.dim_size(batch_dims_);
    for (int64 batch_idx = 0, dest_idx = 0; batch_idx < batch_size;
         ++batch_idx) {
      for (int64 idx = 0; idx < index_inner_size; ++idx) {
        indices_flat(dest_idx++) += batch_offset * batch_idx;
      }
    }
  }

  int32 batch_dims_ = 0;
};

#define REGISTER_GATHER_FULL(dev, type, index_type)                    \
  REGISTER_KERNEL_BUILDER(Name("ResourceGather")                       \
                              .Device(DEVICE_##dev)                    \
                              .HostMemory("resource")                  \
                              .TypeConstraint<type>("dtype")           \
                              .TypeConstraint<index_type>("Tindices"), \
                          ResourceGatherOp<dev##Device, type, index_type>)

#define REGISTER_GATHER_ALL_INDICES(dev, type) \
  REGISTER_GATHER_FULL(dev, type, int32);      \
  REGISTER_GATHER_FULL(dev, type, int64)

#define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type)

// Registration of the CPU implementations.
TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);

// Registers GPU kernels.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)

TF_CALL_int64(REGISTER_GATHER_GPU);
TF_CALL_GPU_ALL_TYPES(REGISTER_GATHER_GPU);

// Variant objects themselves sit on CPU, even if they contain data
// pointing to a device.
REGISTER_KERNEL_BUILDER(Name("ResourceGather")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int32>("Tindices"),
                        ResourceGatherOp<GPUDevice, Variant, int32>)
REGISTER_KERNEL_BUILDER(Name("ResourceGather")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int64>("Tindices"),
                        ResourceGatherOp<GPUDevice, Variant, int64>)

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_GATHER_CPU
#undef REGISTER_GATHER_GPU
#undef REGISTER_GATHER_ALL_INDICES
#undef REGISTER_GATHER_FULL

template <typename Device, typename T, typename Index>
class ResourceGatherNdOp : public OpKernel {
 public:
  explicit ResourceGatherNdOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    core::RefCountPtr<Var> v;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
    // NOTE: We hold the lock for the whole gather operation instead
    // of increasing the reference count of v->tensor() to avoid a
    // situation where a write to the same variable will see a
    // reference count greater than one and make a copy of the
    // (potentially very large) tensor buffer.
    tf_shared_lock ml(*v->mu());
    const Tensor& params = *v->tensor();
    const Tensor& indices = c->input(1);

    Tensor out;
    OP_REQUIRES_OK(
        c, functor::DoGatherNd<Device, T, Index>(c, params, indices, &out));
    c->set_output(0, out);
  }
};

#define REGISTER_GATHER_ND_FULL(dev, type, index_type)                 \
  REGISTER_KERNEL_BUILDER(Name("ResourceGatherNd")                     \
                              .Device(DEVICE_##dev)                    \
                              .HostMemory("resource")                  \
                              .TypeConstraint<type>("dtype")           \
                              .TypeConstraint<index_type>("Tindices"), \
                          ResourceGatherNdOp<dev##Device, type, index_type>)

#define REGISTER_GATHER_ND_ALL_INDICES(dev, type) \
  REGISTER_GATHER_ND_FULL(dev, type, int32);      \
  REGISTER_GATHER_ND_FULL(dev, type, int64)

#define REGISTER_GATHER_ND_CPU(type) REGISTER_GATHER_ND_ALL_INDICES(CPU, type)

// Registration of the CPU implementations.
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);

// Registers GPU kernels.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_GATHER_ND_CPU
#undef REGISTER_GATHER_ND_GPU
#undef REGISTER_GATHER_ND_ALL_INDICES
#undef REGISTER_GATHER_ND_FULL

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
class ResourceScatterUpdateOp : public OpKernel {
 public:
  explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
    // We use the same kernel for many operations.
    // Each operation has a different set of attributes defined in its nodes.
    Status s = c->GetAttr("use_locking", &use_exclusive_lock_);
    if (!s.ok()) {
      use_exclusive_lock_ = false;
    }
  }

  void Compute(OpKernelContext* c) override {
    core::RefCountPtr<Var> v;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
    const bool is_non_pod_dtype = c->input_dtype(0) == DT_RESOURCE ||
                                  c->input_dtype(0) == DT_STRING ||
                                  c->input_dtype(0) == DT_VARIANT;
    if (is_non_pod_dtype || use_exclusive_lock_) {
      mutex_lock ml(*v->mu());
      DoCompute(c);
    } else {
      // For POD dtypes, we can safely run the update while holding the
      // variable's mutex only in shared mode.
      tf_shared_lock ml(*v->mu());
      DoCompute(c);
    }
  }

 private:
  bool use_exclusive_lock_;

  void DoCompute(OpKernelContext* c) {
    core::RefCountPtr<Var> v;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    Tensor* params = v->tensor();
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);

    // Check that rank(updates.shape) = rank(indices.shape + params.shape[1:])
    OP_REQUIRES(c,
                updates.dims() == 0 ||
                    updates.dims() == indices.dims() + params->dims() - 1,
                errors::InvalidArgument(
                    "Must have updates.shape = indices.shape + "
                    "params.shape[1:] or updates.shape = [], got ",
                    "updates.shape ", updates.shape().DebugString(),
                    ", indices.shape ", indices.shape().DebugString(),
                    ", params.shape ", params->shape().DebugString()));
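
    // Illustrative example (hypothetical shapes): params.shape = [8, 4],
    // indices.shape = [3] and updates.shape = [3, 4] pass the check above,
    // since updates.dims() == 1 + 2 - 1; a scalar update (updates.shape = [])
    // broadcasts the same value to every targeted row.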

    // Check that we have enough index space
    const int64 N_big = indices.NumElements();
    OP_REQUIRES(
        c, N_big <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("indices has too many elements for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", N_big, " > ",
                                std::numeric_limits<Index>::max()));
    const Index N = static_cast<Index>(N_big);
    OP_REQUIRES(
        c, params->dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params->dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    if (N > 0) {
      auto indices_flat = indices.flat<Index>();
      auto params_flat = params->flat_outer_dims<T>();
      if (TensorShapeUtils::IsScalar(updates.shape())) {
        const auto update = updates.scalar<T>();

        functor::ScatterScalarFunctor<Device, T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                    params_flat, update, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params->dim_size(0), ")"));
      } else {
        int64 num_updates = updates.NumElements();
        OP_REQUIRES(c, num_updates % N == 0,
                    errors::InvalidArgument(
                        "shape of indices (", indices.shape().DebugString(),
                        ") is not compatible with the shape of updates (",
                        updates.shape().DebugString(), ")"));
        auto updates_flat = updates.shaped<T, 2>({N, num_updates / N});

        functor::ScatterFunctor<Device, T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                    params_flat, updates_flat, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params->dim_size(0), ")"));
      }
    }
  }
};

#define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
  REGISTER_KERNEL_BUILDER(                                             \
      Name(name)                                                       \
          .Device(DEVICE_##dev)                                        \
          .HostMemory("resource")                                      \
          .TypeConstraint<type>("dtype")                               \
          .TypeConstraint<index_type>("Tindices"),                     \
      ResourceScatterUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_SCATTER_KERNEL(type, dev, name, op)         \
  REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
  REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);

#define REGISTER_SCATTER_ARITHMETIC(type, dev)                \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd",    \
                          scatter_op::UpdateOp::ADD);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub",    \
                          scatter_op::UpdateOp::SUB);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul",    \
                          scatter_op::UpdateOp::MUL);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv",    \
                          scatter_op::UpdateOp::DIV);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
                          scatter_op::UpdateOp::ASSIGN);
#define REGISTER_SCATTER_MINMAX(type, dev)                 \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \
                          scatter_op::UpdateOp::MIN);      \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \
                          scatter_op::UpdateOp::MAX);

// Registers CPU kernels.
#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, CPU);
#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);

TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);

REGISTER_SCATTER_KERNEL(tstring, CPU, "ResourceScatterUpdate",
                        scatter_op::UpdateOp::ASSIGN);
REGISTER_SCATTER_KERNEL(bool, CPU, "ResourceScatterUpdate",
                        scatter_op::UpdateOp::ASSIGN);
REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate",
                        scatter_op::UpdateOp::ASSIGN);

// Registers GPU kernels.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, GPU);
#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);

#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_GPU);

REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int32>("Tindices"),
                        ResourceScatterUpdateOp<GPUDevice, Variant, int32,
                                                scatter_op::UpdateOp::ASSIGN>)
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .TypeConstraint<bool>("dtype")
                            .TypeConstraint<int32>("Tindices"),
                        ResourceScatterUpdateOp<GPUDevice, bool, int32,
                                                scatter_op::UpdateOp::ASSIGN>)
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int64>("Tindices"),
                        ResourceScatterUpdateOp<GPUDevice, Variant, int64,
                                                scatter_op::UpdateOp::ASSIGN>)

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_SCATTER_ARITHMETIC
#undef REGISTER_SCATTER_ARITHMETIC_CPU
#undef REGISTER_SCATTER_MINMAX
#undef REGISTER_SCATTER_MINMAX_CPU
#undef REGISTER_SCATTER_KERNEL
#undef REGISTER_SCATTER_KERNEL_INDEX

}  // namespace tensorflow