1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #include <deque>
19 #include <utility>
20 
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/core/framework/resource_mgr.h"
23 #include "tensorflow/core/framework/variant.h"
24 #include "tensorflow/core/framework/variant_encode_decode.h"
25 #include "tensorflow/core/kernels/ops_util.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/threadpool.h"
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow/core/platform/mutex.h"
30 #include "tensorflow/core/platform/types.h"
31 
32 namespace tensorflow {
33 
34 namespace {
35 
36 class Mutex : public ResourceBase {
37  public:
Mutex(OpKernelContext * c,const string & name)38   explicit Mutex(OpKernelContext* c, const string& name)
39       : locked_(false),
40         thread_pool_(new thread::ThreadPool(
41             c->env(), ThreadOptions(),
42             strings::StrCat("mutex_lock_thread_", SanitizeThreadSuffix(name)),
43             1 /* num_threads */, false /* low_latency_hint */)),
44         name_(name) {
45     VLOG(2) << "Creating mutex with name " << name << ": " << this;
46   }
47 
DebugString() const48   string DebugString() const override {
49     return strings::StrCat("Mutex ", name_);
50   }
51 
52   class LockReleaser {
53    public:
LockReleaser(Mutex * mutex)54     explicit LockReleaser(Mutex* mutex) : mutex_(mutex) {}
55 
56     LockReleaser(const LockReleaser&) = delete;
57     LockReleaser& operator=(const LockReleaser&) = delete;
58 
~LockReleaser()59     virtual ~LockReleaser() {
60       VLOG(3) << "Destroying LockReleaser " << this << " for mutex: " << mutex_;
61       if (mutex_) {
62         mutex_lock lock(mutex_->mu_);
63         mutex_->locked_ = false;
64         mutex_->cv_.notify_all();
65         VLOG(3) << "Destroying LockReleaser " << this
66                 << ": sent notifications.";
67       }
68     }
69 
70    private:
71     Mutex* mutex_;
72   };
73 
74   struct SharedLockReleaser {
75     std::shared_ptr<LockReleaser> shared_lock;
76 
SharedLockReleasertensorflow::__anon606fee680111::Mutex::SharedLockReleaser77     explicit SharedLockReleaser(std::shared_ptr<LockReleaser>&& lock)
78         : shared_lock(std::forward<decltype(lock)>(lock)) {
79       VLOG(3) << "Creating shared_ptr of " << shared_lock.get()
80               << " count is: " << shared_lock.use_count();
81     }
82 
SharedLockReleasertensorflow::__anon606fee680111::Mutex::SharedLockReleaser83     SharedLockReleaser(SharedLockReleaser&& rhs)
84         : shared_lock(std::move(rhs.shared_lock)) {
85       VLOG(3) << "Moving SharedLockReleaser of " << shared_lock.get()
86               << " count is: " << shared_lock.use_count();
87     }
88 
SharedLockReleasertensorflow::__anon606fee680111::Mutex::SharedLockReleaser89     SharedLockReleaser(const SharedLockReleaser& rhs)
90         : shared_lock(rhs.shared_lock) {
91       VLOG(3) << "Copying SharedLockReleaser of " << shared_lock.get()
92               << " count is: " << shared_lock.use_count();
93     }
94 
~SharedLockReleasertensorflow::__anon606fee680111::Mutex::SharedLockReleaser95     ~SharedLockReleaser() {
96       VLOG(3) << "Destroying SharedLockReleaser of " << shared_lock.get()
97               << " count is: " << shared_lock.use_count();
98     }
99 
Encodetensorflow::__anon606fee680111::Mutex::SharedLockReleaser100     void Encode(VariantTensorData*) const {
101       // Not supported.
102     }
103 
Decodetensorflow::__anon606fee680111::Mutex::SharedLockReleaser104     bool Decode(const VariantTensorData&) {
105       return false;  // Not supported.
106     }
107   };
108 
AcquireAsync(OpKernelContext * c,std::function<void (const Status & s,SharedLockReleaser lock)> fn)109   void AcquireAsync(
110       OpKernelContext* c,
111       std::function<void(const Status& s, SharedLockReleaser lock)> fn) {
112     CancellationManager* cm = c->cancellation_manager();
113     CancellationToken token{};
114     bool* cancelled = nullptr;
115     if (cm) {
116       cancelled = new bool(false);  // GUARDED_BY(mu_);
117       token = cm->get_cancellation_token();
118       const bool already_cancelled =
119           !cm->RegisterCallback(token, [this, cancelled]() {
120             mutex_lock lock(mu_);
121             *cancelled = true;
122             cv_.notify_all();
123           });
124       if (already_cancelled) {
125         delete cancelled;
126         fn(errors::Cancelled("Lock acquisition cancelled."),
127            SharedLockReleaser{nullptr});
128         return;
129       }
130     }
131     thread_pool_->Schedule(std::bind(
132         [this, cm, cancelled,
133          token](std::function<void(const Status& s, SharedLockReleaser&& lock)>
134                     fn_) {
135           bool local_locked;
136           {
137             mutex_lock lock(mu_);
138             while (locked_ && !(cancelled && *cancelled)) {
139               cv_.wait(lock);
140             }
141             local_locked = locked_ = !(cancelled && *cancelled);
142           }
143           if (cm) {
144             cm->DeregisterCallback(token);
145             delete cancelled;
146           }
147           if (local_locked) {  // Not cancelled.
148             fn_(Status::OK(),
149                 SharedLockReleaser{std::make_shared<LockReleaser>(this)});
150           } else {
151             fn_(errors::Cancelled("Lock acquisition cancelled."),
152                 SharedLockReleaser{nullptr});
153           }
154         },
155         std::move(fn)));
156   }
157 
158  private:
159   mutex mu_;
160   condition_variable cv_ GUARDED_BY(mu_);
161   bool locked_ GUARDED_BY(mu_);
162   std::unique_ptr<thread::ThreadPool> thread_pool_;
163   string name_;
164 };
165 
166 }  // namespace
167 
168 class MutexLockOp : public AsyncOpKernel {
169  public:
MutexLockOp(OpKernelConstruction * c)170   explicit MutexLockOp(OpKernelConstruction* c) : AsyncOpKernel(c) {}
171 
172  public:
ComputeAsync(OpKernelContext * c,DoneCallback done)173   void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
174     Mutex* mutex = nullptr;
175     OP_REQUIRES_OK_ASYNC(
176         c,
177         LookupOrCreateResource<Mutex>(c, HandleFromInput(c, 0), &mutex,
178                                       [c](Mutex** ptr) {
179                                         *ptr = new Mutex(
180                                             c, HandleFromInput(c, 0).name());
181                                         return Status::OK();
182                                       }),
183         done);
184 
185     Tensor* variant;
186     OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, TensorShape({}), &variant),
187                          done);
188 
189     mutex->AcquireAsync(
190         c, std::bind(
191                [c, variant, mutex](DoneCallback done_,
192                                    // End of bound arguments.
193                                    const Status& s,
194                                    Mutex::SharedLockReleaser&& lock) {
195                  VLOG(2) << "Finished locking mutex " << mutex
196                          << " with lock: " << lock.shared_lock.get()
197                          << " status: " << s.ToString();
198                  if (s.ok()) {
199                    variant->scalar<Variant>()() = std::move(lock);
200                  } else {
201                    c->SetStatus(s);
202                  }
203                  mutex->Unref();
204                  done_();
205                },
206                std::move(done), std::placeholders::_1, std::placeholders::_2));
207   }
208 };
209 
210 class ConsumeMutexLockOp : public OpKernel {
211  public:
ConsumeMutexLockOp(OpKernelConstruction * context)212   explicit ConsumeMutexLockOp(OpKernelConstruction* context)
213       : OpKernel(context) {}
214 
Compute(OpKernelContext * c)215   void Compute(OpKernelContext* c) override {
216     VLOG(2) << "Executing ConsumeMutexLockOp";
217     const Tensor& lock_t = c->input(0);
218     OP_REQUIRES(
219         c, lock_t.dims() == 0,
220         errors::InvalidArgument("Expected input to be a scalar, saw shape: ",
221                                 lock_t.shape().DebugString()));
222     OP_REQUIRES(
223         c, lock_t.dtype() == DT_VARIANT,
224         errors::InvalidArgument("Expected input to be a variant, saw type: ",
225                                 DataTypeString(lock_t.dtype())));
226     const auto* lock =
227         lock_t.scalar<Variant>()().get<Mutex::SharedLockReleaser>();
228     OP_REQUIRES(c, lock,
229                 errors::InvalidArgument(
230                     "Expected input to contain a SharedLockReleaser "
231                     "object, but saw variant: '",
232                     lock_t.scalar<Variant>()().DebugString(), "'"));
233     const int use_count = lock->shared_lock.use_count();
234     OP_REQUIRES(
235         c, use_count == 1,
236         errors::InvalidArgument("Expected use count of lock to be 1, but saw: ",
237                                 use_count));
238   }
239 
IsExpensive()240   bool IsExpensive() override { return false; }
241 };
242 
// Kernel registrations. All three kernels operate on host-side objects (the
// Mutex resource and the SharedLockReleaser stored in a Variant tensor), so
// the GPU registrations pin the relevant inputs/outputs with HostMemory().

// MutexLock: acquires the lock; the Mutex resource itself is created lazily
// by MutexLockOp via LookupOrCreateResource.
REGISTER_KERNEL_BUILDER(Name("MutexLock").Device(DEVICE_CPU), MutexLockOp);

REGISTER_KERNEL_BUILDER(Name("MutexLock")
                            .Device(DEVICE_GPU)
                            .HostMemory("mutex_lock")
                            .HostMemory("mutex"),
                        MutexLockOp);

// MutexV2: produces the resource handle for a Mutex.
// NOTE(review): the CPU registration declares HostMemory("resource") while
// the GPU one does not -- presumably harmless because resource handles are
// host-resident, but confirm the asymmetry is intentional.
REGISTER_KERNEL_BUILDER(
    Name("MutexV2").Device(DEVICE_CPU).HostMemory("resource"),
    ResourceHandleOp<Mutex>);

REGISTER_KERNEL_BUILDER(Name("MutexV2").Device(DEVICE_GPU),
                        ResourceHandleOp<Mutex>);

// ConsumeMutexLock: validates that the lock variant is held by one tensor.
REGISTER_KERNEL_BUILDER(Name("ConsumeMutexLock").Device(DEVICE_CPU),
                        ConsumeMutexLockOp);

REGISTER_KERNEL_BUILDER(
    Name("ConsumeMutexLock").Device(DEVICE_GPU).HostMemory("mutex_lock"),
    ConsumeMutexLockOp);
264 
265 }  // namespace tensorflow
266