/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Our general strategy for preventing conflicts between concurrent
// reads and writes of resource variables is to:
// * For read operations, we:
//   - acquire the variable's mutex (in "shared" mode);
//   - make a (shallow) copy of the Tensor object, which increments
//     the reference count on the variable's TensorBuffer;
//   - release the variable's mutex;
//   - use the copy of the Tensor object to do the read.
// * For write operations, we:
//   - acquire the variable's mutex (in "exclusive" mode);
//   - check the reference count of the variable's TensorBuffer and,
//     if it is >1, make a deep copy of the variable's Tensor;
//   - mutate the variable's Tensor;
//   - and release the variable's mutex.
// This allows several read operations to all use the same
// TensorBuffer without needing to copy. When it comes time to write,
// a copy is made only if there is an outstanding read using the
// buffer. Write operations are serialized by the variable's mutex.
//
// For sparse operations (scatter, gather, sparse optimizer updates),
// we need to avoid copies, since there may not be enough memory for
// two copies of the whole tensor. To support this, we make two
// modifications to the above strategy:
// * For sparse reads (gather), we hold the variable's mutex (still in
//   "shared" mode) for the duration of the whole read. This means
//   that as long as you only do sparse read operations no write will
//   see the reference count >1.
// * For sparse write operations where the user explicitly specifies
//   that they want to perform the write without locks held
//   (use_locking=false), we never copy, even if the variable's
//   reference count is >1.
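//
// A minimal sketch of the two dense paths described above, kept in comment
// form and using a hypothetical DeepCopy helper purely for illustration; the
// kernels below go through OpKernelContext and the functors in
// dense_update_functor.h rather than calling anything like this directly:
//
//   // Dense read: shallow copy under a shared lock, then read lock-free.
//   Tensor ReadSketch(Var* variable) {
//     tf_shared_lock l(*variable->mu());
//     return *variable->tensor();  // bumps the TensorBuffer refcount
//   }
//
//   // Dense write: copy-on-write under an exclusive lock.
//   void WriteSketch(Var* variable, const Tensor& value) {
//     mutex_lock l(*variable->mu());
//     if (!variable->tensor()->RefCountIsOne()) {
//       *variable->tensor() = DeepCopy(*variable->tensor());  // hypothetical
//     }
//     // ... mutate *variable->tensor() in place using `value` ...
//   }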

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

#include "tensorflow/core/kernels/resource_variable_ops.h"

#include <memory>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/gather_functor.h"
#include "tensorflow/core/kernels/gather_nd_op.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp").Device(DEVICE_CPU),
                        ResourceHandlesOp<Var>);

ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
  OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
}

namespace {

Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
  Tensor* output;
  Notification n;
  Status status;
  AllocatorAttributes attr;
  if (t->dtype() == DT_VARIANT) {
    attr.set_on_host(true);
  }
  TF_RETURN_IF_ERROR(
      ctx->allocate_output(output_idx, t->shape(), &output, attr));
  if (t->dtype() == DT_VARIANT) {
    output->flat<Variant>() = t->flat<Variant>();
  } else if (ctx->op_device_context() != nullptr) {
    // TODO(apassos): remove the down_cast by just returning Device* from
    // OpKernelContext
    Device* device = down_cast<Device*>(ctx->device());
    ctx->op_device_context()->CopyTensorInSameDevice(
        t, device, output, [&n, &status](const Status& s) {
          status = s;
          n.Notify();
        });
    n.WaitForNotification();
    return status;
  } else {
    switch (t->dtype()) {
#define HANDLER(type)                       \
  case DataTypeToEnum<type>::value:         \
    output->flat<type>() = t->flat<type>(); \
    break;
      TF_CALL_ALL_TYPES(HANDLER);
#undef HANDLER
      default:
        return errors::Internal("Unsupported dtype", t->dtype());
    }
  }
  return Status::OK();
}

}  // namespace

void ReadVariableOp::Compute(OpKernelContext* ctx) {
  core::RefCountPtr<Var> variable;
  const ResourceHandle& handle = HandleFromInput(ctx, 0);
  const auto status = LookupResource(ctx, handle, &variable);
  OP_REQUIRES(ctx, status.ok(),
              errors::FailedPrecondition(
                  "Could not find variable ", handle.name(), ". ",
                  "This could mean that the variable has been deleted. ",
                  "In TF1, it can also mean the variable is uninitialized. ",
                  "Debug info: container=", handle.container(),
                  ", status=", status.ToString()));

  tf_shared_lock ml(*variable->mu());
  // We're acquiring a reference to the underlying buffer while
  // holding a shared lock to guarantee ordering of reads and
  // writes when in copy-on-write mode.
  const Tensor* t = variable->tensor();
  if (!variable->copy_on_read_mode.load()) {
    OP_REQUIRES(
        ctx, dtype_ == t->dtype(),
        errors::InvalidArgument(
            "Trying to read variable with wrong dtype. Expected ",
            DataTypeString(dtype_), " got ", DataTypeString(t->dtype())));
    ctx->set_output(0, *t);
  } else {
    OP_REQUIRES_OK(ctx, CopyVariable(0, ctx, t));
  }
}

ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
  int n;
  OP_REQUIRES_OK(c, c->GetAttr("N", &n));
  OP_REQUIRES_OK(c, c->GetAttr("dtypes", &dtypes_));
  OP_REQUIRES(c, n == dtypes_.size(),
              errors::InvalidArgument(
                  "Mismatched number of arguments to ReadVariablesOp (", n,
                  " vs. ", dtypes_.size(), ")"));
}

void ReadVariablesOp::Compute(OpKernelContext* ctx) {
  std::vector<core::RefCountPtr<Var>> variables(dtypes_.size());
  std::vector<const ResourceHandle*> handles(dtypes_.size());
  for (size_t i = 0; i < dtypes_.size(); ++i) {
    handles[i] = &HandleFromInput(ctx, i);
  }

  OP_REQUIRES_OK(ctx, LookupResources(ctx, handles, &variables));

  std::vector<string> uninitialized_vars;
  for (int64 i = 0; i < variables.size(); i++) {
    if (variables[i] == nullptr) {
      uninitialized_vars.push_back(handles[i]->name());
    }
  }

  OP_REQUIRES(ctx, uninitialized_vars.empty(),
              errors::FailedPrecondition(
                  "In ReadVariablesOp the following variables were "
                  "found uninitialized: ",
                  absl::StrJoin(uninitialized_vars, ", ")));

  for (size_t i = 0; i < dtypes_.size(); ++i) {
    // We're acquiring a reference to the underlying buffer while
    // holding a shared lock to guarantee ordering of reads and
    // writes.
    tf_shared_lock ml(*variables[i]->mu());
    OP_REQUIRES(ctx, dtypes_[i] == variables[i]->tensor()->dtype(),
                errors::InvalidArgument(
                    "Trying to read variable ", handles[i]->name(),
                    " from Container: ", handles[i]->container(),
                    " with wrong dtype. Expected ", DataTypeString(dtypes_[i]),
                    " got ", DataTypeString(variables[i]->tensor()->dtype())));
    if (variables[i]->copy_on_read_mode.load()) {
      OP_REQUIRES_OK(ctx, CopyVariable(i, ctx, variables[i]->tensor()));
    } else {
      const Tensor& t = *variables[i]->tensor();
      ctx->set_output(i, t);
    }
  }
}

REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
                        ReadVariableOp);
REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp").Device(DEVICE_CPU),
                        ReadVariablesOp);

VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
  OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
  OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));

  OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_and_shape_.dtype));
  OP_REQUIRES_OK(context, context->GetAttr("shape", &dtype_and_shape_.shape));

  is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME;

  if (!is_anonymous_) {
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(context, context->allocate_temp(DT_RESOURCE, TensorShape({}),
                                                   &resource_, attr));
    resource_.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
        context, container_, name_,
        std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
  }
}

void VarHandleOp::Compute(OpKernelContext* ctx) {
  if (is_anonymous_) {
    AllocatorAttributes attr;
    attr.set_on_host(true);
    Tensor handle;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
    handle.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
        ctx, container_, name_,
        std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
        ctx->stack_trace());
    ctx->set_output(0, handle);
  } else {
    ctx->set_output(0, resource_);
  }
}

REGISTER_KERNEL_BUILDER(Name("VarHandleOp").Device(DEVICE_CPU), VarHandleOp);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(
    Name("ReadVariableOp").Device(DEVICE_GPU).HostMemory("resource"),
    ReadVariableOp);
REGISTER_KERNEL_BUILDER(
    Name("_ReadVariablesOp").Device(DEVICE_GPU).HostMemory("resources"),
    ReadVariablesOp);

#define REGISTER_GPU_KERNELS(type)                             \
  namespace functor {                                          \
  template <>                                                  \
  void DenseUpdate<GPUDevice, type, ASSIGN>::operator()(       \
      const GPUDevice& d, typename TTypes<type>::Flat lhs,     \
      typename TTypes<type>::ConstFlat rhs);                   \
  extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
  }                                                            \
  REGISTER_KERNEL_BUILDER(Name("VarHandleOp")                  \
                              .Device(DEVICE_GPU)              \
                              .HostMemory("resource")          \
                              .TypeConstraint<type>("dtype"),  \
                          VarHandleOp)
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS

REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
                            .Device(DEVICE_GPU)
                            .HostMemory("resources")
                            .TypeConstraint("dtypes",
                                            {DT_INT64, DT_COMPLEX64,
                                             DT_COMPLEX128, DT_HALF, DT_FLOAT,
                                             DT_DOUBLE, DT_BOOL, DT_VARIANT}),
                        ResourceHandlesOp<Var>);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int32>("out_type"),
    VariableShapeOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int64>("out_type"),
    VariableShapeOp<int64>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("VariableShape")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("out_type")
                            .HostMemory("output")
                            .HostMemory("input"),
                        VariableShapeOp<int32>);
REGISTER_KERNEL_BUILDER(Name("VariableShape")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int64>("out_type")
                            .HostMemory("output")
                            .HostMemory("input"),
                        VariableShapeOp<int64>);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

DestroyResourceOp::DestroyResourceOp(OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr("ignore_lookup_error", &ignore_lookup_error_));
}

void DestroyResourceOp::Compute(OpKernelContext* ctx) {
  const ResourceHandle& p = HandleFromInput(ctx, 0);
  Status status = DeleteResource(ctx, p);
  if (ignore_lookup_error_ && errors::IsNotFound(status)) {
    return;
  }
  OP_REQUIRES_OK(ctx, status);
}

REGISTER_KERNEL_BUILDER(Name("DestroyResourceOp").Device(DEVICE_CPU),
                        DestroyResourceOp);
REGISTER_KERNEL_BUILDER(
    Name("DestroyResourceOp").Device(DEVICE_GPU).HostMemory("resource"),
    DestroyResourceOp);

template <typename Device, typename T>
class AssignVariableOp : public OpKernel {
 public:
  explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
    if (!c->GetAttr("_grappler_relax_allocator_constraints",
                    &relax_constraints_)
             .ok()) {
      relax_constraints_ = false;
    }
  }

  void Compute(OpKernelContext* context) override {
    OP_REQUIRES(context, dtype_ == context->input(1).dtype(),
                errors::InvalidArgument(
                    "Variable and value dtypes don't match; respectively, ",
                    DataTypeString(dtype_), " and ",
                    DataTypeString(context->input(1).dtype())));
    core::RefCountPtr<Var> variable;
    const Tensor& value = context->input(1);
    // Note: every resource-variable-manipulating op assumes copy-on-write
    // semantics, and creates a copy of the variable's Tensor if its refcount
    // is bigger than 1 when we try to modify it. This means we never need to
    // copy the original tensor for AssignVariableOp; even if there are other
    // live users of it, we know none of them can modify it, so this is always
    // safe (even in esoteric cases where the same tensor is used to
    // initialize multiple variables, or the tensor is a constant, this is
    // safe, because future writes will trigger copies).
    OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
                                context, HandleFromInput(context, 0), &variable,
                                [this, &value](Var** ptr) {
                                  *ptr = new Var(dtype_);
                                  *(*ptr)->tensor() = value;
                                  (*ptr)->is_initialized = true;
                                  return Status::OK();
                                }));
    mutex_lock ml(*variable->mu());
    // The (variable->tensor()->dtype() == DT_INVALID &&
    // !variable->is_initialized) check below allows an XLA-specific situation
    // in which the AssignVariableOp can run before the variable has been
    // initialized. When using TF-XLA, this scenario is possible when the
    // execution uses the 'fallback' path (which essentially invokes
    // TensorFlow ops via partitioned_call).
    OP_REQUIRES(context,
                (variable->tensor()->dtype() == DT_INVALID &&
                 !variable->is_initialized) ||
                    variable->tensor()->dtype() == dtype_,
                errors::InvalidArgument(
                    "Trying to assign variable with wrong dtype. Expected ",
                    DataTypeString(variable->tensor()->dtype()), " got ",
                    DataTypeString(dtype_)));
    if (variable->copy_on_read_mode.load()) {
      PersistentTensor unused;
      Tensor* tmp;
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      attr.set_nic_compatible(true);
      OP_REQUIRES_OK(context,
                     context->allocate_persistent(value.dtype(), value.shape(),
                                                  &unused, &tmp, attr));
      functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
      copy_functor(context->eigen_device<Device>(), tmp->flat<T>(),
                   value.flat<T>());
      *variable->tensor() = *tmp;
    } else {
      *variable->tensor() = value;
    }
    variable->is_initialized = true;
  }

 private:
  DataType dtype_;
  bool relax_constraints_;
};

template <typename Device>
class AssignVariableOp<Device, Variant> : public OpKernel {
 public:
  explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
    OP_REQUIRES(c, dtype_ == DT_VARIANT,
                errors::Internal("Variant kernel called with dtype: ",
                                 DataTypeString(dtype_)));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& value = context->input(1);
    core::RefCountPtr<Var> variable;
    OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
                                context, HandleFromInput(context, 0), &variable,
                                [](Var** ptr) {
                                  // Created on host.
                                  *ptr = new Var(DT_VARIANT);
                                  return Status::OK();
                                }));

    // For purposes of forwarding DT_VARIANT, we want the least
    // restrictive attr; we already know the input is on host.
    AllocatorAttributes attr;

    // Copying is unnecessary if we are the last user of the value
    // tensor; in that case we can just adopt the input tensor's buffer
    // instead. Note that Variant objects themselves always reside on host.
    //
    // We nevertheless want to signal to the runtime that the tensor
    // should reside in memory of the associated device, as Variant
    // tensors may be marked as sitting on either CPU or GPU. This
    // helps to elide one or more copies.
    std::unique_ptr<Tensor> input_alias = context->forward_input(
        1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT,
        value.shape(),
        DEVICE_MEMORY /* HOST_MEMORY is only reserved for special cases */,
        attr);

    mutex_lock ml(*variable->mu());
    OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
                errors::InvalidArgument(
                    "Trying to assign variable with wrong dtype. Expected ",
                    DataTypeString(variable->tensor()->dtype()), " got ",
                    DataTypeString(DT_VARIANT)));
    variable->is_initialized = true;
    *variable->tensor() = Tensor(DT_VARIANT, value.shape());

    if (input_alias) {
      *variable->tensor() = *input_alias;
      return;
    }

    // Need to copy, but maybe we can re-use the variable's buffer?
    if (!variable->tensor()->RefCountIsOne() ||
        !variable->tensor()->shape().IsSameSize(value.shape())) {
      PersistentTensor unused;
      Tensor* tmp;
      // Allocation of DT_VARIANT is always on host.
      attr.set_on_host(true);
      OP_REQUIRES_OK(context,
                     context->allocate_persistent(DT_VARIANT, value.shape(),
                                                  &unused, &tmp, attr));
      *variable->tensor() = *tmp;
    }

    const auto elements_in = value.flat<Variant>();
    auto elements_out = variable->tensor()->flat<Variant>();
    for (int64 i = 0; i < elements_in.size(); ++i) {
      elements_out(i) = elements_in(i);
    }
  }

 private:
  DataType dtype_;
};

#define REGISTER_KERNELS(type)                                \
  REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")            \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<type>("dtype"), \
                          AssignVariableOp<Eigen::ThreadPoolDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type)                            \
  REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")            \
                              .Device(DEVICE_GPU)             \
                              .TypeConstraint<type>("dtype")  \
                              .HostMemory("resource"),        \
                          AssignVariableOp<GPUDevice, type>);

TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T, DenseUpdateType Op>
class AssignUpdateVariableOp : public OpKernel {
 public:
  explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* context) override {
    core::RefCountPtr<Var> variable;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &variable));

    const Tensor& value = context->input(1);
    // TODO(apassos): We could possibly avoid the copy done by
    // PrepareToUpdateVariable() for commutative operations like Op ==
    // ADD if value's refcount was 1.
    mutex_lock ml(*variable->mu());
    Tensor* var_tensor = variable->tensor();
    OP_REQUIRES(context, var_tensor->shape().IsSameSize(value.shape()),
                errors::InvalidArgument("Cannot update variable with shape ",
                                        var_tensor->shape().DebugString(),
                                        " using a Tensor with shape ",
                                        value.shape().DebugString(),
                                        ", shapes must be equal."));
    OP_REQUIRES_OK(
        context, PrepareToUpdateVariable<Device, T>(
                     context, var_tensor, variable->copy_on_read_mode.load()));
    functor::DenseUpdate<Device, T, Op> update_functor;
    update_functor(context->eigen_device<Device>(), var_tensor->flat<T>(),
                   value.flat<T>());
  }
};

#define REGISTER_KERNELS(type)                                     \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AssignAddVariableOp")                                  \
          .Device(DEVICE_CPU)                                      \
          .TypeConstraint<type>("dtype"),                          \
      AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, ADD>); \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AssignSubVariableOp")                                  \
          .Device(DEVICE_CPU)                                      \
          .TypeConstraint<type>("dtype"),                          \
      AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, SUB>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp")                    \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("resource")                    \
                              .TypeConstraint<type>("dtype"),            \
                          AssignUpdateVariableOp<GPUDevice, type, ADD>); \
  REGISTER_KERNEL_BUILDER(Name("AssignSubVariableOp")                    \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("resource")                    \
                              .TypeConstraint<type>("dtype"),            \
                          AssignUpdateVariableOp<GPUDevice, type, SUB>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

class VarIsInitializedOp : public OpKernel {
 public:
  explicit VarIsInitializedOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* context) override {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}), &output));
    auto output_tensor = output->tensor<bool, 0>();
    core::RefCountPtr<Var> variable;
    Status s = LookupResource(context, HandleFromInput(context, 0), &variable);
    if (!s.ok()) {
      output_tensor() = false;
      return;
    }
    mutex_lock ml(*variable->mu());
    output_tensor() = variable->is_initialized;
  }
};

REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp").Device(DEVICE_CPU),
                        VarIsInitializedOp);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("is_initialized"),
                        IsResourceInitialized<Var>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T, typename Index>
class ResourceGatherOp : public OpKernel {
 public:
  explicit ResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("batch_dims", &batch_dims_));
  }

  void Compute(OpKernelContext* c) override {
    core::RefCountPtr<Var> v;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
    // NOTE: We hold the lock for the whole gather operation instead
    // of increasing the reference count of v->tensor() to avoid a
    // situation where a write to the same variable will see a
    // reference count greater than one and make a copy of the
    // (potentially very large) tensor buffer.
    tf_shared_lock ml(*v->mu());
    const Tensor& params = *v->tensor();
    const Tensor& indices = c->input(1);
    OP_REQUIRES(
        c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
        errors::InvalidArgument("params must be at least 1 dimensional"));

    // Check that we have enough index space
    const int64 N = indices.NumElements();
    OP_REQUIRES(
        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params.dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    // The result shape is params.shape[:batch_dims] +
    // indices.shape[batch_dims:] + params.shape[batch_dims+1:].
    TensorShape result_shape;
    for (int i = 0; i < batch_dims_; ++i) {
      result_shape.AddDim(params.dim_size(i));
    }
    for (int i = batch_dims_; i < indices.dims(); ++i) {
      result_shape.AddDim(indices.dim_size(i));
    }
    for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
      result_shape.AddDim(params.dim_size(i));
    }
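    // A concrete illustration of the formula above (hypothetical shapes, not
    // from the original comments): with batch_dims = 1, params.shape =
    // [2, 4, 5] and indices.shape = [2, 3], result_shape ends up as
    // [2] + [3] + [5] = [2, 3, 5].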

    Tensor* out = nullptr;
    Tensor tmp;
    if (params.dtype() == DT_VARIANT) {
      tmp = Tensor(DT_VARIANT, result_shape);
      c->set_output(0, tmp);
      out = &tmp;
    } else {
      OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
    }

    if (N > 0) {
      Tensor tmp_indices;

      // Points to the original or updated (if batch_dims is set) indices.
      const Tensor* op_indices = &indices;
      if (batch_dims_ > 0) {
        OP_REQUIRES_OK(c, c->allocate_temp(indices.dtype(), indices.shape(),
                                           &tmp_indices));
        functor::DenseUpdate<Device, Index, ASSIGN> copy_functor;
        copy_functor(c->eigen_device<Device>(), tmp_indices.flat<Index>(),
                     indices.flat<Index>());

        AddBatchOffsets(&tmp_indices, params);
        op_indices = &tmp_indices;
      }

      int64 gather_dim_size = 1;
      for (int idx = 0; idx <= batch_dims_; ++idx) {
        gather_dim_size *= params.dim_size(idx);
      }
      int64 inner_size = 1;
      for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
        inner_size *= params.dim_size(i);
      }
      auto params_flat = params.shaped<T, 3>({1, gather_dim_size, inner_size});
      const auto indices_flat = op_indices->flat<Index>();
      auto out_flat = out->shaped<T, 3>({1, N, out->NumElements() / N});

      functor::GatherFunctor<Device, T, Index> functor;
      int64 bad_i = functor(c, params_flat, indices_flat, out_flat);

      OP_REQUIRES(
          c, bad_i < 0,
          errors::InvalidArgument(
              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
              indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
    }
  }

 private:
  // Add the batch offset derived from params to each batch of indices.
  // Example: batch_dims = 1, indices = [[0, 1, 2], [0, 1, 2]].
  // If indexing into a params dimension of size 4, then the indices will
  // become [0, 1, 2, 4, 5, 6].
  void AddBatchOffsets(Tensor* indices, const Tensor& params) {
    int64 batch_size = 1;  // The size of all batch dimensions.
    for (int idx = 0; idx < batch_dims_; ++idx) {
      batch_size *= params.dim_size(idx);
    }

    auto indices_flat = indices->flat<Index>();
    int64 const index_inner_size = indices->NumElements() / batch_size;
    int64 const batch_offset = params.dim_size(batch_dims_);
    for (int64 batch_idx = 0, dest_idx = 0; batch_idx < batch_size;
         ++batch_idx) {
      for (int64 idx = 0; idx < index_inner_size; ++idx) {
        indices_flat(dest_idx++) += batch_offset * batch_idx;
      }
    }
  }

  int32 batch_dims_ = 0;
};

#define REGISTER_GATHER_FULL(dev, type, index_type)                    \
  REGISTER_KERNEL_BUILDER(Name("ResourceGather")                       \
                              .Device(DEVICE_##dev)                    \
                              .HostMemory("resource")                  \
                              .TypeConstraint<type>("dtype")           \
                              .TypeConstraint<index_type>("Tindices"), \
                          ResourceGatherOp<dev##Device, type, index_type>)

#define REGISTER_GATHER_ALL_INDICES(dev, type) \
  REGISTER_GATHER_FULL(dev, type, int32);      \
  REGISTER_GATHER_FULL(dev, type, int64)

#define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type)

// Registration of the CPU implementations.
TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);

// Registers GPU kernels.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)

TF_CALL_int64(REGISTER_GATHER_GPU);
TF_CALL_GPU_ALL_TYPES(REGISTER_GATHER_GPU);

// Variant objects themselves sit on CPU, even if they contain data
// pointing to a device.
REGISTER_KERNEL_BUILDER(Name("ResourceGather")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int32>("Tindices"),
                        ResourceGatherOp<GPUDevice, Variant, int32>)
REGISTER_KERNEL_BUILDER(Name("ResourceGather")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int64>("Tindices"),
                        ResourceGatherOp<GPUDevice, Variant, int64>)

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_GATHER_CPU
#undef REGISTER_GATHER_GPU
#undef REGISTER_GATHER_ALL_INDICES
#undef REGISTER_GATHER_FULL

template <typename Device, typename T, typename Index>
class ResourceGatherNdOp : public OpKernel {
 public:
  explicit ResourceGatherNdOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    core::RefCountPtr<Var> v;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
    // NOTE: We hold the lock for the whole gather operation instead
    // of increasing the reference count of v->tensor() to avoid a
    // situation where a write to the same variable will see a
    // reference count greater than one and make a copy of the
    // (potentially very large) tensor buffer.
    tf_shared_lock ml(*v->mu());
    const Tensor& params = *v->tensor();
    const Tensor& indices = c->input(1);

    Tensor out;
    OP_REQUIRES_OK(
        c, functor::DoGatherNd<Device, T, Index>(c, params, indices, &out));
    c->set_output(0, out);
  }
};

#define REGISTER_GATHER_ND_FULL(dev, type, index_type)                 \
  REGISTER_KERNEL_BUILDER(Name("ResourceGatherNd")                     \
                              .Device(DEVICE_##dev)                    \
                              .HostMemory("resource")                  \
                              .TypeConstraint<type>("dtype")           \
                              .TypeConstraint<index_type>("Tindices"), \
                          ResourceGatherNdOp<dev##Device, type, index_type>)

#define REGISTER_GATHER_ND_ALL_INDICES(dev, type) \
  REGISTER_GATHER_ND_FULL(dev, type, int32);      \
  REGISTER_GATHER_ND_FULL(dev, type, int64)

#define REGISTER_GATHER_ND_CPU(type) REGISTER_GATHER_ND_ALL_INDICES(CPU, type)

// Registration of the CPU implementations.
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);

// Registers GPU kernels.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_GATHER_ND_CPU
#undef REGISTER_GATHER_ND_GPU
#undef REGISTER_GATHER_ND_ALL_INDICES
#undef REGISTER_GATHER_ND_FULL

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
class ResourceScatterUpdateOp : public OpKernel {
 public:
  explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
    // We use the same kernel for many operations.
    // Each operation has a different set of attributes defined on its node.
    Status s = c->GetAttr("use_locking", &use_exclusive_lock_);
    if (!s.ok()) {
      use_exclusive_lock_ = false;
    }
  }

  void Compute(OpKernelContext* c) override {
    core::RefCountPtr<Var> v;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
    const bool is_non_pod_dtype = c->input_dtype(0) == DT_RESOURCE ||
                                  c->input_dtype(0) == DT_STRING ||
                                  c->input_dtype(0) == DT_VARIANT;
    if (is_non_pod_dtype || use_exclusive_lock_) {
      mutex_lock ml(*v->mu());
      DoCompute(c);
    } else {
      // For POD dtypes, we can safely run the update under a shared lock
      // rather than holding the mutex exclusively.
      tf_shared_lock ml(*v->mu());
      DoCompute(c);
    }
  }

 private:
  bool use_exclusive_lock_;

  void DoCompute(OpKernelContext* c) {
    core::RefCountPtr<Var> v;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    Tensor* params = v->tensor();
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);

    // Check that rank(updates.shape) = rank(indices.shape + params.shape[1:])
    OP_REQUIRES(c,
                updates.dims() == 0 ||
                    updates.dims() == indices.dims() + params->dims() - 1,
                errors::InvalidArgument(
                    "Must have updates.shape = indices.shape + "
                    "params.shape[1:] or updates.shape = [], got ",
                    "updates.shape ", updates.shape().DebugString(),
                    ", indices.shape ", indices.shape().DebugString(),
                    ", params.shape ", params->shape().DebugString()));
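    // A concrete illustration of this check (hypothetical shapes, not from
    // the original comments): with params.shape = [4, 3] and indices.shape =
    // [2], a full update needs updates.shape = [2, 3], so updates.dims() ==
    // 1 + 2 - 1 == 2; a scalar update (updates.shape = []) is also accepted
    // and is applied to every indexed row.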

    // Check that we have enough index space
    const int64 N_big = indices.NumElements();
    OP_REQUIRES(
        c, N_big <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("indices has too many elements for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", N_big, " > ",
                                std::numeric_limits<Index>::max()));
    const Index N = static_cast<Index>(N_big);
    OP_REQUIRES(
        c, params->dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params->dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    if (N > 0) {
      auto indices_flat = indices.flat<Index>();
      auto params_flat = params->flat_outer_dims<T>();
      if (TensorShapeUtils::IsScalar(updates.shape())) {
        const auto update = updates.scalar<T>();

        functor::ScatterScalarFunctor<Device, T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                    params_flat, update, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params->dim_size(0), ")"));
      } else {
        int64 num_updates = updates.NumElements();
        OP_REQUIRES(c, num_updates % N == 0,
                    errors::InvalidArgument(
                        "shape of indices (", indices.shape().DebugString(),
                        ") is not compatible with the shape of updates (",
                        updates.shape().DebugString(), ")"));
        auto updates_flat = updates.shaped<T, 2>({N, num_updates / N});

        functor::ScatterFunctor<Device, T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                    params_flat, updates_flat, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params->dim_size(0), ")"));
      }
    }
  }
};

#define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
  REGISTER_KERNEL_BUILDER(                                             \
      Name(name)                                                       \
          .Device(DEVICE_##dev)                                        \
          .HostMemory("resource")                                      \
          .TypeConstraint<type>("dtype")                               \
          .TypeConstraint<index_type>("Tindices"),                     \
      ResourceScatterUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_SCATTER_KERNEL(type, dev, name, op)         \
  REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
  REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);

#define REGISTER_SCATTER_ARITHMETIC(type, dev)                \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd",    \
                          scatter_op::UpdateOp::ADD);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub",    \
                          scatter_op::UpdateOp::SUB);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul",    \
                          scatter_op::UpdateOp::MUL);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv",    \
                          scatter_op::UpdateOp::DIV);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
                          scatter_op::UpdateOp::ASSIGN);
#define REGISTER_SCATTER_MINMAX(type, dev)                 \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \
                          scatter_op::UpdateOp::MIN);      \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \
                          scatter_op::UpdateOp::MAX);

// Registers CPU kernels.
#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, CPU);
#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);

TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);

REGISTER_SCATTER_KERNEL(tstring, CPU, "ResourceScatterUpdate",
                        scatter_op::UpdateOp::ASSIGN);
REGISTER_SCATTER_KERNEL(bool, CPU, "ResourceScatterUpdate",
                        scatter_op::UpdateOp::ASSIGN);
REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate",
                        scatter_op::UpdateOp::ASSIGN);

// Registers GPU kernels.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, GPU);
#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);

#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_GPU);

REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int32>("Tindices"),
                        ResourceScatterUpdateOp<GPUDevice, Variant, int32,
                                                scatter_op::UpdateOp::ASSIGN>)
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .TypeConstraint<bool>("dtype")
                            .TypeConstraint<int32>("Tindices"),
                        ResourceScatterUpdateOp<GPUDevice, bool, int32,
                                                scatter_op::UpdateOp::ASSIGN>)
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int64>("Tindices"),
                        ResourceScatterUpdateOp<GPUDevice, Variant, int64,
                                                scatter_op::UpdateOp::ASSIGN>)

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_SCATTER_ARITHMETIC
#undef REGISTER_SCATTER_ARITHMETIC_CPU
#undef REGISTER_SCATTER_MINMAX
#undef REGISTER_SCATTER_MINMAX_CPU
#undef REGISTER_SCATTER_KERNEL
#undef REGISTER_SCATTER_KERNEL_INDEX

}  // namespace tensorflow