/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
#define TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_

#include <algorithm>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/refcount.h"

namespace tensorflow {

// Must be called before performing a sparse operation on a variable. Ensures
// that no concurrent dense operations can happen while holding the variable's
// lock.
template <typename Device, typename T>
Status EnsureSparseVariableAccess(OpKernelContext* ctx, Var* var) {
  if (var->copy_on_read_mode.load()) {
    return Status::OK();
  }
  mutex_lock ml(*var->mu());
  // Once copy-on-read mode is true the refcount is guaranteed to be 1. The
  // refcount can also be 1 if there are no concurrent reads of the variable
  // while copy-on-read mode is false, in which case we can flip the mode
  // without copying.
  if (var->tensor()->RefCountIsOne()) {
    var->copy_on_read_mode.store(true);
    return Status::OK();
  }
  PersistentTensor unused;
  Tensor* tmp;
  if (std::is_same<T, Variant>::value) {
    // Variant buffers must live on the host, and their elements are copied
    // one by one rather than with a device memcpy.
    AllocatorAttributes attr;
    attr.set_on_host(true);
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        var->tensor()->dtype(), var->tensor()->shape(), &unused, &tmp, attr));

    const auto elements_in = var->tensor()->flat<Variant>();
    auto elements_out = tmp->flat<Variant>();
    for (int64 i = 0; i < elements_in.size(); ++i) {
      elements_out(i) = elements_in(i);
    }
  } else {
    AllocatorAttributes attr;
    attr.set_gpu_compatible(true);
    attr.set_nic_compatible(true);
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        var->tensor()->dtype(), var->tensor()->shape(), &unused, &tmp, attr));
    functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
    copy_functor(ctx->eigen_device<Device>(), tmp->flat<T>(),
                 const_cast<const Tensor*>(var->tensor())->flat<T>());
  }
  *var->tensor() = *tmp;
  var->copy_on_read_mode.store(true);
  return Status::OK();
}
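
// Example usage (an illustrative sketch, not part of this header's API;
// `CPUDevice` is the usual Eigen device alias defined in kernel .cc files):
// a sparse training kernel would call this before scattering into a resource
// variable, so that concurrent dense reads copy instead of aliasing:
//
//   Var* var = nullptr;
//   TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &var));
//   core::ScopedUnref unref(var);
//   TF_RETURN_IF_ERROR(
//       (EnsureSparseVariableAccess<CPUDevice, float>(ctx, var)));
//   // The variable is now in copy-on-read mode; sparse updates may be
//   // applied while holding *var->mu().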

// Utility structure that releases a sequence of borrowed mutexes when it is
// deleted.
struct VariableInputLockHolder {
 public:
  VariableInputLockHolder(
      std::vector<Var*> vars, std::unique_ptr<std::vector<mutex_lock>> locks,
      std::unique_ptr<std::vector<tf_shared_lock>> shared_locks)
      : vars_(std::move(vars)),
        locks_(std::move(locks)),
        shared_locks_(std::move(shared_locks)) {}

  VariableInputLockHolder(VariableInputLockHolder&& other)
      : vars_(std::move(other.vars_)),
        locks_(std::move(other.locks_)),
        shared_locks_(std::move(other.shared_locks_)) {}

  ~VariableInputLockHolder() {
    // Release both kinds of locks before unreffing the Vars, because each
    // lock is potentially borrowed from a Var in vars_ and must not outlive
    // its owner.
    locks_.reset();
    shared_locks_.reset();
    for (Var* var : vars_) {
      var->Unref();
    }
  }

 private:
  std::vector<Var*> vars_;
  // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly,
  // because a `std::vector<mutex_lock>` is not movable on all platforms.
  std::unique_ptr<std::vector<mutex_lock>> locks_;
  std::unique_ptr<std::vector<tf_shared_lock>> shared_locks_;
};

// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`.
//
// If `input` corresponds to a `DT_RESOURCE`-type variable input,
// `*maybe_resource` will be updated to contain the underlying resource, and
// the caller will be responsible for calling `Unref()` on that resource.
template <typename Device, typename T>
mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, bool sparse,
                                Var** maybe_resource) {
  *maybe_resource = nullptr;
  if (ctx->input_dtype(input) == DT_RESOURCE) {
    if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) {
      if (sparse) {
        EnsureSparseVariableAccess<Device, T>(ctx, *maybe_resource)
            .IgnoreError();
      }
      return (*maybe_resource)->mu();
    } else {
      ctx->CtxFailureWithWarning(
          errors::Internal("Invalid variable reference."));
      return nullptr;
    }
  }
  return ctx->input_ref_mutex(input);
}
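
// Example usage (illustrative): when a resource is returned through
// `maybe_resource`, release any lock taken on the mutex before calling
// Unref(), since the mutex is owned by the Var:
//
//   Var* resource = nullptr;
//   mutex* mu = GetTrainingVariableMutex<CPUDevice, float>(
//       ctx, /*input=*/0, /*sparse=*/false, &resource);
//   if (mu != nullptr) {
//     mutex_lock l(*mu);
//     // ... read or update the variable ...
//   }  // lock released here
//   if (resource != nullptr) resource->Unref();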

// MaybeLockVariableInputMutexesInOrder is a helper function that acquires
// mutexes in address order, to mitigate deadlock. It returns a structure
// that, when deleted, releases the acquired mutexes. It is safe to pass
// duplicate inputs; each distinct mutex is locked only once. If `sparse` is
// true, the variables are switched to copy-on-read mode before the locks are
// acquired. If `do_lock` is false, this returns immediately for reference
// variables; for resource variables in copy-on-read mode it grabs a shared
// lock when `do_lock` is false and an exclusive lock otherwise. Note that
// this silently skips locking mutexes for invalid variable references; in
// all usages this is followed by GetInputTensorFromVariable, which will
// signal a failure.
template <typename Device, typename T>
VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
    OpKernelContext* ctx, bool do_lock, bool sparse,
    const std::vector<int>& input_ids) {
  bool any_resource = false;
  for (auto i : input_ids) {
    if (ctx->input_dtype(i) == DT_RESOURCE) {
      any_resource = true;
      break;
    }
  }
  if (!do_lock && !any_resource) {
    return VariableInputLockHolder({}, {}, {});
  }
  std::vector<Var*> vars;
  std::vector<mutex*> mutexes;
  std::vector<int> acquire_order;
  for (auto input : input_ids) {
    Var* var;
    mutex* mu = GetTrainingVariableMutex<Device, T>(ctx, input, sparse, &var);
    if (var) vars.push_back(var);
    // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
    if (std::find(mutexes.begin(), mutexes.end(), mu) == mutexes.end()) {
      acquire_order.push_back(mutexes.size());
      mutexes.push_back(mu);
    }
  }
  std::sort(acquire_order.begin(), acquire_order.end(),
            [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });

  auto locks = absl::make_unique<std::vector<mutex_lock>>();
  auto shared_locks = absl::make_unique<std::vector<tf_shared_lock>>();
  locks->reserve(acquire_order.size());

  // `acquire_order` holds indices into `mutexes`, so lock the deduplicated
  // mutexes directly rather than re-resolving the inputs (which would treat
  // these indices as input ids).
  for (auto acquire : acquire_order) {
    mutex* mu = mutexes[acquire];
    if (mu != nullptr) {
      if (!sparse || do_lock) {
        locks->emplace_back(*mu);
      } else {
        shared_locks->emplace_back(*mu);
      }
    }
  }
  return VariableInputLockHolder(std::move(vars), std::move(locks),
                                 std::move(shared_locks));
}
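
// Example usage (an illustrative sketch; `use_exclusive_lock_` is a typical
// kernel attribute, not defined here): a dense training kernel acquires all
// of its variable locks up front and holds them for the whole update:
//
//   void Compute(OpKernelContext* ctx) override {
//     auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
//         ctx, use_exclusive_lock_, /*sparse=*/false, {0, 1});
//     // ... update the variables at inputs 0 and 1 ...
//   }  // `locks` is destroyed here, releasing the mutexes.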

void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
                                     int output);

// This is for use with ResourceVariables to ensure *tensor has a
// reference count of 1 before you update it.
// REQUIRES: If you pass in variable->tensor(), *variable->mu() must be held.
template <typename Device, typename T>
Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor,
                               bool copy_on_read_mode) {
  if (copy_on_read_mode || !tensor->RefCountIsOne()) {
    // Tensor's buffer is in use by some read, so we need to copy before
    // updating.
    PersistentTensor unused;
    Tensor* tmp;
    if (std::is_same<T, Variant>::value) {
      AllocatorAttributes attr;
      attr.set_on_host(true);
      TF_RETURN_IF_ERROR(ctx->allocate_persistent(
          tensor->dtype(), tensor->shape(), &unused, &tmp, attr));

      const auto elements_in = tensor->flat<Variant>();
      auto elements_out = tmp->flat<Variant>();
      for (int64 i = 0; i < elements_in.size(); ++i) {
        elements_out(i) = elements_in(i);
      }
    } else {
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      attr.set_nic_compatible(true);
      TF_RETURN_IF_ERROR(ctx->allocate_persistent(
          tensor->dtype(), tensor->shape(), &unused, &tmp, attr));
      functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
      copy_functor(ctx->eigen_device<Device>(), tmp->flat<T>(),
                   const_cast<const Tensor*>(tensor)->flat<T>());
    }
    *tensor = *tmp;
  }
  return Status::OK();
}
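
// Example usage (illustrative): with *var->mu() held, make the buffer safe
// to mutate in place before writing to it:
//
//   mutex_lock ml(*var->mu());
//   TF_RETURN_IF_ERROR((PrepareToUpdateVariable<CPUDevice, float>(
//       ctx, var->tensor(), var->copy_on_read_mode.load())));
//   var->tensor()->flat<float>().setZero();  // safe: refcount is now 1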

// This gives you `*out`, a tensor you can update, corresponding to a variable
// passed as input index `input`. This handles the differences between
// reference and resource variables. For reference variables we can just
// return the tensor, acquiring the reference-input lock if `lock_held` is
// false.
//
// For resource variables: if `sparse` is true, we ensure the variable is in
// copy-on-read mode; otherwise we ensure its refcount is 1 (by potentially
// copying its contents). In either case `lock_held` is ignored.
template <typename Device, typename T>
Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
                                  bool lock_held, bool sparse, Tensor* out) {
  if (ctx->input_dtype(input) == DT_RESOURCE) {
    core::RefCountPtr<Var> var;
    TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
    if (sparse) {
      TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var.get()));
      *out = *var->tensor();
      return Status::OK();
    }
    TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, T>(
        ctx, var->tensor(), var->copy_on_read_mode.load()));
    *out = *var->tensor();
    return Status::OK();
  }
  *out = ctx->mutable_input(input, lock_held);
  return Status::OK();
}
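
// Example usage (illustrative): the usual pattern in a training kernel's
// Compute() combines the lock helper above with this accessor:
//
//   auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
//       ctx, use_exclusive_lock_, /*sparse=*/false, {0});
//   Tensor var_tensor;
//   OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
//                           ctx, 0, use_exclusive_lock_, /*sparse=*/false,
//                           &var_tensor));
//   // var_tensor aliases the variable's buffer and is now safe to update.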

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_