1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #include <deque> 19 #include <utility> 20 21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 22 #include "tensorflow/core/framework/resource_mgr.h" 23 #include "tensorflow/core/framework/variant.h" 24 #include "tensorflow/core/framework/variant_encode_decode.h" 25 #include "tensorflow/core/kernels/ops_util.h" 26 #include "tensorflow/core/lib/core/errors.h" 27 #include "tensorflow/core/lib/core/threadpool.h" 28 #include "tensorflow/core/platform/macros.h" 29 #include "tensorflow/core/platform/mutex.h" 30 #include "tensorflow/core/platform/types.h" 31 32 namespace tensorflow { 33 34 namespace { 35 36 class Mutex : public ResourceBase { 37 public: Mutex(OpKernelContext * c,const string & name)38 explicit Mutex(OpKernelContext* c, const string& name) 39 : locked_(false), 40 thread_pool_(new thread::ThreadPool( 41 c->env(), ThreadOptions(), 42 strings::StrCat("mutex_lock_thread_", SanitizeThreadSuffix(name)), 43 1 /* num_threads */, false /* low_latency_hint */)), 44 name_(name) { 45 VLOG(2) << "Creating mutex with name " << name << ": " << this; 46 } 47 DebugString() const48 string DebugString() const override { 49 return strings::StrCat("Mutex ", name_); 50 } 51 52 class LockReleaser { 53 public: LockReleaser(Mutex * mutex)54 explicit LockReleaser(Mutex* mutex) : mutex_(mutex) {} 55 56 LockReleaser(const LockReleaser&) = delete; 57 LockReleaser& operator=(const LockReleaser&) = delete; 58 ~LockReleaser()59 virtual ~LockReleaser() { 60 VLOG(3) << "Destroying LockReleaser " << this << " for mutex: " << mutex_; 61 if (mutex_) { 62 mutex_lock lock(mutex_->mu_); 63 mutex_->locked_ = false; 64 mutex_->cv_.notify_all(); 65 VLOG(3) << "Destroying LockReleaser " << this 66 << ": sent notifications."; 67 } 68 } 69 70 private: 71 Mutex* mutex_; 72 }; 73 74 struct SharedLockReleaser { 75 std::shared_ptr<LockReleaser> shared_lock; 76 SharedLockReleasertensorflow::__anon606fee680111::Mutex::SharedLockReleaser77 explicit SharedLockReleaser(std::shared_ptr<LockReleaser>&& lock) 78 : shared_lock(std::forward<decltype(lock)>(lock)) { 79 VLOG(3) << "Creating shared_ptr of " << shared_lock.get() 80 << " count is: " << shared_lock.use_count(); 81 } 82 SharedLockReleasertensorflow::__anon606fee680111::Mutex::SharedLockReleaser83 SharedLockReleaser(SharedLockReleaser&& rhs) 84 : shared_lock(std::move(rhs.shared_lock)) { 85 VLOG(3) << "Moving SharedLockReleaser of " << shared_lock.get() 86 << " count is: " << shared_lock.use_count(); 87 } 88 SharedLockReleasertensorflow::__anon606fee680111::Mutex::SharedLockReleaser89 SharedLockReleaser(const SharedLockReleaser& rhs) 90 : shared_lock(rhs.shared_lock) { 91 VLOG(3) << "Copying SharedLockReleaser of " << shared_lock.get() 92 << " count is: " << shared_lock.use_count(); 93 } 94 ~SharedLockReleasertensorflow::__anon606fee680111::Mutex::SharedLockReleaser95 ~SharedLockReleaser() { 96 VLOG(3) << "Destroying SharedLockReleaser of " << shared_lock.get() 97 << " count is: " << shared_lock.use_count(); 98 } 99 Encodetensorflow::__anon606fee680111::Mutex::SharedLockReleaser100 void Encode(VariantTensorData*) const { 101 // Not supported. 102 } 103 Decodetensorflow::__anon606fee680111::Mutex::SharedLockReleaser104 bool Decode(const VariantTensorData&) { 105 return false; // Not supported. 106 } 107 }; 108 AcquireAsync(OpKernelContext * c,std::function<void (const Status & s,SharedLockReleaser lock)> fn)109 void AcquireAsync( 110 OpKernelContext* c, 111 std::function<void(const Status& s, SharedLockReleaser lock)> fn) { 112 CancellationManager* cm = c->cancellation_manager(); 113 CancellationToken token{}; 114 bool* cancelled = nullptr; 115 if (cm) { 116 cancelled = new bool(false); // GUARDED_BY(mu_); 117 token = cm->get_cancellation_token(); 118 const bool already_cancelled = 119 !cm->RegisterCallback(token, [this, cancelled]() { 120 mutex_lock lock(mu_); 121 *cancelled = true; 122 cv_.notify_all(); 123 }); 124 if (already_cancelled) { 125 delete cancelled; 126 fn(errors::Cancelled("Lock acquisition cancelled."), 127 SharedLockReleaser{nullptr}); 128 return; 129 } 130 } 131 thread_pool_->Schedule(std::bind( 132 [this, cm, cancelled, 133 token](std::function<void(const Status& s, SharedLockReleaser&& lock)> 134 fn_) { 135 bool local_locked; 136 { 137 mutex_lock lock(mu_); 138 while (locked_ && !(cancelled && *cancelled)) { 139 cv_.wait(lock); 140 } 141 local_locked = locked_ = !(cancelled && *cancelled); 142 } 143 if (cm) { 144 cm->DeregisterCallback(token); 145 delete cancelled; 146 } 147 if (local_locked) { // Not cancelled. 148 fn_(Status::OK(), 149 SharedLockReleaser{std::make_shared<LockReleaser>(this)}); 150 } else { 151 fn_(errors::Cancelled("Lock acquisition cancelled."), 152 SharedLockReleaser{nullptr}); 153 } 154 }, 155 std::move(fn))); 156 } 157 158 private: 159 mutex mu_; 160 condition_variable cv_ GUARDED_BY(mu_); 161 bool locked_ GUARDED_BY(mu_); 162 std::unique_ptr<thread::ThreadPool> thread_pool_; 163 string name_; 164 }; 165 166 } // namespace 167 168 class MutexLockOp : public AsyncOpKernel { 169 public: MutexLockOp(OpKernelConstruction * c)170 explicit MutexLockOp(OpKernelConstruction* c) : AsyncOpKernel(c) {} 171 172 public: ComputeAsync(OpKernelContext * c,DoneCallback done)173 void ComputeAsync(OpKernelContext* c, DoneCallback done) override { 174 Mutex* mutex = nullptr; 175 OP_REQUIRES_OK_ASYNC( 176 c, 177 LookupOrCreateResource<Mutex>(c, HandleFromInput(c, 0), &mutex, 178 [c](Mutex** ptr) { 179 *ptr = new Mutex( 180 c, HandleFromInput(c, 0).name()); 181 return Status::OK(); 182 }), 183 done); 184 185 Tensor* variant; 186 OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, TensorShape({}), &variant), 187 done); 188 189 mutex->AcquireAsync( 190 c, std::bind( 191 [c, variant, mutex](DoneCallback done_, 192 // End of bound arguments. 193 const Status& s, 194 Mutex::SharedLockReleaser&& lock) { 195 VLOG(2) << "Finished locking mutex " << mutex 196 << " with lock: " << lock.shared_lock.get() 197 << " status: " << s.ToString(); 198 if (s.ok()) { 199 variant->scalar<Variant>()() = std::move(lock); 200 } else { 201 c->SetStatus(s); 202 } 203 mutex->Unref(); 204 done_(); 205 }, 206 std::move(done), std::placeholders::_1, std::placeholders::_2)); 207 } 208 }; 209 210 class ConsumeMutexLockOp : public OpKernel { 211 public: ConsumeMutexLockOp(OpKernelConstruction * context)212 explicit ConsumeMutexLockOp(OpKernelConstruction* context) 213 : OpKernel(context) {} 214 Compute(OpKernelContext * c)215 void Compute(OpKernelContext* c) override { 216 VLOG(2) << "Executing ConsumeMutexLockOp"; 217 const Tensor& lock_t = c->input(0); 218 OP_REQUIRES( 219 c, lock_t.dims() == 0, 220 errors::InvalidArgument("Expected input to be a scalar, saw shape: ", 221 lock_t.shape().DebugString())); 222 OP_REQUIRES( 223 c, lock_t.dtype() == DT_VARIANT, 224 errors::InvalidArgument("Expected input to be a variant, saw type: ", 225 DataTypeString(lock_t.dtype()))); 226 const auto* lock = 227 lock_t.scalar<Variant>()().get<Mutex::SharedLockReleaser>(); 228 OP_REQUIRES(c, lock, 229 errors::InvalidArgument( 230 "Expected input to contain a SharedLockReleaser " 231 "object, but saw variant: '", 232 lock_t.scalar<Variant>()().DebugString(), "'")); 233 const int use_count = lock->shared_lock.use_count(); 234 OP_REQUIRES( 235 c, use_count == 1, 236 errors::InvalidArgument("Expected use count of lock to be 1, but saw: ", 237 use_count)); 238 } 239 IsExpensive()240 bool IsExpensive() override { return false; } 241 }; 242 243 REGISTER_KERNEL_BUILDER(Name("MutexLock").Device(DEVICE_CPU), MutexLockOp); 244 245 REGISTER_KERNEL_BUILDER(Name("MutexLock") 246 .Device(DEVICE_GPU) 247 .HostMemory("mutex_lock") 248 .HostMemory("mutex"), 249 MutexLockOp); 250 251 REGISTER_KERNEL_BUILDER( 252 Name("MutexV2").Device(DEVICE_CPU).HostMemory("resource"), 253 ResourceHandleOp<Mutex>); 254 255 REGISTER_KERNEL_BUILDER(Name("MutexV2").Device(DEVICE_GPU), 256 ResourceHandleOp<Mutex>); 257 258 REGISTER_KERNEL_BUILDER(Name("ConsumeMutexLock").Device(DEVICE_CPU), 259 ConsumeMutexLockOp); 260 261 REGISTER_KERNEL_BUILDER( 262 Name("ConsumeMutexLock").Device(DEVICE_GPU).HostMemory("mutex_lock"), 263 ConsumeMutexLockOp); 264 265 } // namespace tensorflow 266