/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
#define TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_

#include <algorithm>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/refcount.h"

namespace tensorflow {

// Must be called before performing a sparse operation on a variable. Ensures
// that no concurrent dense operations can happen while holding the variable's
// lock.
template <typename Device, typename T>
Status EnsureSparseVariableAccess(OpKernelContext* ctx, Var* var) {
  if (var->copy_on_read_mode.load()) {
    return Status::OK();
  }
  mutex_lock ml(*var->mu());
  // Once copy-on-read mode is true the refcount is guaranteed to be 1. The
  // refcount can also be 1 when there are no concurrent reads of the variable
  // while copy-on-read mode is false.
  if (var->tensor()->RefCountIsOne()) {
    var->copy_on_read_mode.store(true);
    return Status::OK();
  }
  PersistentTensor unused;
  Tensor* tmp;
  if (std::is_same<T, Variant>::value) {
    AllocatorAttributes attr;
    attr.set_on_host(true);
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        var->tensor()->dtype(), var->tensor()->shape(), &unused, &tmp, attr));

    const auto elements_in = var->tensor()->flat<Variant>();
    auto elements_out = tmp->flat<Variant>();
    for (int64 i = 0; i < elements_in.size(); ++i) {
      elements_out(i) = elements_in(i);
    }
  } else {
    AllocatorAttributes attr;
    attr.set_gpu_compatible(true);
    attr.set_nic_compatible(true);
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        var->tensor()->dtype(), var->tensor()->shape(), &unused, &tmp, attr));
    functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
    copy_functor(ctx->eigen_device<Device>(), tmp->flat<T>(),
                 const_cast<const Tensor*>(var->tensor())->flat<T>());
  }
  *var->tensor() = *tmp;
  var->copy_on_read_mode.store(true);
  return Status::OK();
}
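
// Illustrative usage sketch (not part of this header; `Device`, `T`, and the
// kernel's `ctx` are assumed from the calling op). A sparse-update path first
// ensures copy-on-read mode, after which a shared lock suffices, so multiple
// sparse updates can proceed concurrently:
//
//   Var* var = nullptr;
//   TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &var));
//   core::ScopedUnref unref(var);
//   TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var));
//   tf_shared_lock ml(*var->mu());  // Shared is enough in copy-on-read mode.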

// Utility structure that releases a sequence of borrowed mutexes when it is
// deleted.
struct VariableInputLockHolder {
 public:
  VariableInputLockHolder(
      std::vector<Var*> vars, std::unique_ptr<std::vector<mutex_lock>> locks,
      std::unique_ptr<std::vector<tf_shared_lock>> shared_locks)
      : vars_(std::move(vars)),
        locks_(std::move(locks)),
        shared_locks_(std::move(shared_locks)) {}

  VariableInputLockHolder(VariableInputLockHolder&& other)
      : vars_(std::move(other.vars_)),
        locks_(std::move(other.locks_)),
        shared_locks_(std::move(other.shared_locks_)) {}

  ~VariableInputLockHolder() {
    // Release both kinds of locks before unreffing the Vars, because each
    // lock is potentially borrowed from a Var in vars_.
    locks_.reset();
    shared_locks_.reset();
    for (Var* var : vars_) {
      var->Unref();
    }
  }

 private:
  std::vector<Var*> vars_;
  // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly,
  // because a `std::vector<mutex_lock>` is not movable on all platforms.
  std::unique_ptr<std::vector<mutex_lock>> locks_;
  std::unique_ptr<std::vector<tf_shared_lock>> shared_locks_;
};

// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`.
//
// If `input` corresponds to a `DT_RESOURCE`-type variable input,
// `*maybe_resource` will be updated to contain the underlying resource, and the
// caller will be responsible for calling `Unref()` on that resource.
template <typename Device, typename T>
mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, bool sparse,
                                Var** maybe_resource) {
  *maybe_resource = nullptr;
  if (ctx->input_dtype(input) == DT_RESOURCE) {
    if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) {
      if (sparse) {
        EnsureSparseVariableAccess<Device, T>(ctx, *maybe_resource)
            .IgnoreError();
      }
      return (*maybe_resource)->mu();
    } else {
      ctx->CtxFailureWithWarning(
          errors::Internal("Invalid variable reference."));
      return nullptr;
    }
  }
  return ctx->input_ref_mutex(input);
}
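
// Illustrative sketch (hypothetical caller, not from this file): the caller
// owns any `Var` reference returned through `maybe_resource`, so pair the
// call with `core::ScopedUnref`, and check for a null mutex, which signals an
// invalid variable reference:
//
//   Var* var = nullptr;
//   mutex* mu =
//       GetTrainingVariableMutex<Device, T>(ctx, 0, /*sparse=*/false, &var);
//   core::ScopedUnref unref(var);  // ScopedUnref is a no-op on nullptr.
//   if (mu != nullptr) {
//     mutex_lock ml(*mu);
//     // ... update the variable while holding the lock ...
//   }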

// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
// in address order to mitigate deadlock. Returns a structure that, when
// deleted, will release the acquired mutexes. It is safe to pass duplicate
// inputs: each distinct mutex is locked only once. If `sparse` is true, the
// variables are switched to copy-on-read mode before the locks are acquired.
// If `do_lock` is false, this returns immediately for reference variables; for
// resource variables in copy-on-read mode it grabs a shared lock when
// `do_lock` is false and an exclusive lock otherwise. Note that this silently
// doesn't lock mutexes for invalid variable references; in all usages this is
// followed by GetInputTensorFromVariable, which will signal a failure.
template <typename Device, typename T>
VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
    OpKernelContext* ctx, bool do_lock, bool sparse,
    const std::vector<int>& input_ids) {
  bool any_resource = false;
  for (auto i : input_ids) {
    if (ctx->input_dtype(i) == DT_RESOURCE) {
      any_resource = true;
      break;
    }
  }
  if (!do_lock && !any_resource) {
    return VariableInputLockHolder({}, {}, {});
  }
  std::vector<Var*> vars;
  std::vector<mutex*> mutexes;
  std::vector<int> acquire_order;
  for (auto input : input_ids) {
    Var* var;
    mutex* mutex =
        GetTrainingVariableMutex<Device, T>(ctx, input, sparse, &var);
    if (var) vars.push_back(var);
    // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
    if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
      acquire_order.push_back(mutexes.size());
      mutexes.push_back(mutex);
    }
  }
  std::sort(acquire_order.begin(), acquire_order.end(),
            [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });

  auto locks = absl::make_unique<std::vector<mutex_lock>>();
  auto shared_locks = absl::make_unique<std::vector<tf_shared_lock>>();
  locks->reserve(acquire_order.size());

  for (auto acquire : acquire_order) {
    // `acquire_order` holds indices into `mutexes` (now sorted by mutex
    // address), not op input indices, so look the mutex up directly rather
    // than re-resolving the variable input.
    mutex* mu = mutexes[acquire];
    if (mu != nullptr) {
      if (!sparse || do_lock) {
        locks->emplace_back(*mu);
      } else {
        shared_locks->emplace_back(*mu);
      }
    }
  }
  return VariableInputLockHolder(std::move(vars), std::move(locks),
                                 std::move(shared_locks));
}
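
// Illustrative sketch (hypothetical kernel fragment; `use_exclusive_lock_` is
// an assumed kernel attribute): a training kernel locks all of its variable
// inputs up front and relies on the returned holder to release the locks and
// the borrowed Var references at scope exit:
//
//   auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
//       ctx, use_exclusive_lock_, /*sparse=*/false, {0, 1, 2});
//   // ... mutate the variables' tensors ...
//   // Locks and Var references are released when `locks` is destroyed.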

// If the input at index `input` is a (non-resource) reference variable,
// forwards it to the ref output at index `output`.
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
                                     int output);

// This is for use with ResourceVariables to ensure *tensor has a
// reference count of 1 before you update it.
// REQUIRES: If you pass in variable->tensor(), *variable->mu() must be held.
template <typename Device, typename T>
Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor,
                               bool copy_on_read_mode) {
  if (copy_on_read_mode || !tensor->RefCountIsOne()) {
    // Tensor's buffer is in use by some read, so we need to copy before
    // updating.
    PersistentTensor unused;
    Tensor* tmp;
    if (std::is_same<T, Variant>::value) {
      AllocatorAttributes attr;
      attr.set_on_host(true);
      TF_RETURN_IF_ERROR(ctx->allocate_persistent(
          tensor->dtype(), tensor->shape(), &unused, &tmp, attr));

      const auto elements_in = tensor->flat<Variant>();
      auto elements_out = tmp->flat<Variant>();
      for (int64 i = 0; i < elements_in.size(); ++i) {
        elements_out(i) = elements_in(i);
      }
    } else {
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      attr.set_nic_compatible(true);
      TF_RETURN_IF_ERROR(ctx->allocate_persistent(
          tensor->dtype(), tensor->shape(), &unused, &tmp, attr));
      functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
      copy_functor(ctx->eigen_device<Device>(), tmp->flat<T>(),
                   const_cast<const Tensor*>(tensor)->flat<T>());
    }
    *tensor = *tmp;
  }
  return Status::OK();
}
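
// Illustrative sketch (hypothetical dense-assign fragment; `var` and `value`
// are assumed from the calling kernel): hold the variable's mutex, make the
// buffer uniquely owned, then write through it:
//
//   mutex_lock ml(*var->mu());
//   TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, T>(
//       ctx, var->tensor(), var->copy_on_read_mode.load()));
//   functor::DenseUpdate<Device, T, ASSIGN> assign;
//   assign(ctx->eigen_device<Device>(), var->tensor()->flat<T>(),
//          value.flat<T>());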

// This gives you `*out`, a tensor you can update, corresponding to the
// variable passed as input index `input`. It handles the differences between
// reference and resource variables. For reference variables we just grab the
// tensor, acquiring the lock if `lock_held` is false.
//
// For resource variables, if `sparse` is true we ensure the variable is in
// copy-on-read mode (which also guarantees a uniquely owned buffer);
// otherwise we ensure its refcount is 1 by potentially copying its contents.
// In both resource cases `lock_held` is ignored.
template <typename Device, typename T>
Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
                                  bool lock_held, bool sparse, Tensor* out) {
  if (ctx->input_dtype(input) == DT_RESOURCE) {
    core::RefCountPtr<Var> var;
    TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
    if (sparse) {
      TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var.get()));
      *out = *var->tensor();
      return Status::OK();
    }
    TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, T>(
        ctx, var->tensor(), var->copy_on_read_mode.load()));
    *out = *var->tensor();
    return Status::OK();
  }
  *out = ctx->mutable_input(input, lock_held);
  return Status::OK();
}
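
// Illustrative sketch (hypothetical `Compute` body for a dense training
// kernel; `use_exclusive_lock_` is an assumed attribute): lock the variable
// inputs, fetch an updatable tensor, apply the update in place, and forward
// any ref input to the ref output:
//
//   void Compute(OpKernelContext* ctx) override {
//     const bool sparse = false;
//     auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
//         ctx, use_exclusive_lock_, sparse, {0});
//     Tensor var;
//     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
//                             ctx, 0, use_exclusive_lock_, sparse, &var));
//     // ... apply the update to `var` in place ...
//     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
//   }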

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_