1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/contrib/gdr/gdr_rendezvous_mgr.h"
17 
18 #include "google/protobuf/any.pb.h"
19 #include "tensorflow/contrib/gdr/gdr_memory_manager.h"
20 #include "tensorflow/core/common_runtime/device.h"
21 #include "tensorflow/core/common_runtime/device_mgr.h"
22 #include "tensorflow/core/common_runtime/process_util.h"
23 #include "tensorflow/core/distributed_runtime/request_id.h"
24 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
25 #include "tensorflow/core/distributed_runtime/worker_cache.h"
26 #include "tensorflow/core/distributed_runtime/worker_interface.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/strings/numbers.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/macros.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 namespace tensorflow {
36 
37 namespace {
38 
39 class GdrRecvTensorCall : public BaseRecvTensorCall {
40  public:
GdrRecvTensorCall(WorkerInterface * wi,Device * dst_device,RemoteMemoryManager * remote_memory_manager,const Rendezvous::Args & recv_args,int64 step_id,StringPiece key)41   GdrRecvTensorCall(WorkerInterface* wi, Device* dst_device,
42                     RemoteMemoryManager* remote_memory_manager,
43                     const Rendezvous::Args& recv_args, int64 step_id,
44                     StringPiece key)
45       : wi_(wi),
46         dst_device_(dst_device),
47         remote_memory_manager_(remote_memory_manager),
48         recv_args_(recv_args) {
49     req_.set_step_id(step_id);
50     req_.set_rendezvous_key(key.data(), key.size());
51     req_.set_request_id(GetUniqueRequestId());
52   }
53 
~GdrRecvTensorCall()54   ~GdrRecvTensorCall() override {}
55 
Start(std::function<void ()> recv_done)56   void Start(std::function<void()> recv_done) override {
57     req_.set_dma_ok(true);
58     resp_.InitAlloc(dst_device_, recv_args_.alloc_attrs);
59     StatusCallback cb = [this, recv_done](const Status& s) {
60       bool dma_ok = resp_.metadata().has_transport_options();
61       if (s.ok() && tensor().TotalBytes() > 1024 && (!is_dead()) && dma_ok) {
62         auto transport_options = resp_.metadata().transport_options();
63         const bool on_host = recv_args_.alloc_attrs.on_host();
64         remote_memory_manager_->TensorFromTransportOptions(
65             const_cast<Tensor*>(&tensor()), transport_options, dst_device_,
66             recv_args_.device_context, on_host,
67             [this, recv_done](const Status& s) {
68               if (!s.ok()) {
69                 mutex_lock l(mu_);
70                 status_.Update(s);
71               }
72               recv_done();
73             });
74         return;
75       }
76       if (!s.ok()) {
77         mutex_lock l(mu_);
78         status_.Update(s);
79       }
80       recv_done();
81     };
82     wi_->RecvTensorAsync(&opts_, &req_, &resp_, std::move(cb));
83   }
84 
StartAbort(const Status & s)85   void StartAbort(const Status& s) override {
86     {
87       mutex_lock l(mu_);
88       status_.Update(s);
89     }
90     opts_.StartCancel();
91   }
92 
status() const93   Status status() const override {
94     mutex_lock l(mu_);
95     return status_;
96   }
97 
tensor() const98   const Tensor& tensor() const { return resp_.tensor(); }
99 
is_dead() const100   bool is_dead() const { return resp_.metadata().is_dead(); }
101 
dst_device() const102   Device* dst_device() const { return dst_device_; }
103 
recv_args() const104   const Rendezvous::Args& recv_args() const { return recv_args_; }
105 
106  private:
107   WorkerInterface* wi_;
108   Device* dst_device_;
109   RemoteMemoryManager* remote_memory_manager_;
110   CallOptions opts_;
111   RecvTensorRequest req_;
112   TensorResponse resp_;
113   Rendezvous::Args recv_args_;
114 
115   mutable mutex mu_;
116   Status status_ GUARDED_BY(mu_);
117 
118   TF_DISALLOW_COPY_AND_ASSIGN(GdrRecvTensorCall);
119 };
120 
121 class GdrRemoteRendezvous : public BaseRemoteRendezvous {
122  public:
GdrRemoteRendezvous(const WorkerEnv * env,int64 step_id,RemoteMemoryManager * remote_memory_manager)123   GdrRemoteRendezvous(const WorkerEnv* env, int64 step_id,
124                       RemoteMemoryManager* remote_memory_manager)
125       : BaseRemoteRendezvous(env, step_id),
126         remote_memory_manager_(remote_memory_manager) {}
127 
128  protected:
RecvFromRemoteAsync(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & recv_args,DoneCallback done)129   void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
130                            const Rendezvous::Args& recv_args,
131                            DoneCallback done) override {
132     CHECK(is_initialized());
133 
134     string src_worker;
135     string src_rel_device;
136     if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_worker,
137                                           &src_rel_device)) {
138       Status s = errors::Internal(parsed.src_device,
139                                   " is invalid remote source device.");
140       done(s, Args(), recv_args, Tensor{}, false);
141       return;
142     }
143 
144     WorkerSession* sess = session();
145     WorkerInterface* rwi = sess->worker_cache->CreateWorker(src_worker);
146     if (rwi == nullptr) {
147       Status s = errors::Internal("No worker known as ", src_worker);
148       done(s, Args(), recv_args, Tensor{}, false);
149       return;
150     }
151 
152     Device* dst_device;
153     Status s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
154     if (!s.ok()) {
155       sess->worker_cache->ReleaseWorker(src_worker, rwi);
156       done(s, Args(), recv_args, Tensor{}, false);
157       return;
158     }
159 
160     // Prepare a RecvTensor call that can handle being aborted.
161     GdrRecvTensorCall* call =
162         new GdrRecvTensorCall(rwi, dst_device, remote_memory_manager_,
163                               recv_args, step_id_, parsed.FullKey());
164 
165     // Record "call" in active_ so that it can be aborted cleanly.
166     RegisterCall(call);
167 
168     // RendezvousMgr already aborted, shouldn't send RPC call any more
169     if (!call->status().ok()) {
170       // NOTE: `*session()` can potentially be deleted before we return from
171       // `call->done()(...)`, so we must release the worker before calling the
172       // callback.
173       session()->worker_cache->ReleaseWorker(src_worker, rwi);
174       done(call->status(), Args(), Args(), Tensor(), false);
175       delete call;
176       return;
177     }
178 
179     // Start "call".
180     Ref();
181     call->Start([this, call, src_worker, rwi, done]() {
182       // Removes "call" from active_. Prevent StartAbort().
183       DeregisterCall(call);
184       // If StartAbort was called prior to DeregisterCall, then the
185       // current status should be bad.
186       Status s = call->status();
187       // NOTE: `*session()` can potentially be deleted before we return from
188       // `call->done()(...)`, so we must release the worker before calling the
189       // callback.
190       session()->worker_cache->ReleaseWorker(src_worker, rwi);
191       done(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
192       delete call;
193       Unref();
194     });
195   }
196 
197  private:
~GdrRemoteRendezvous()198   ~GdrRemoteRendezvous() override {}
199 
200   RemoteMemoryManager* remote_memory_manager_;
201 
202   TF_DISALLOW_COPY_AND_ASSIGN(GdrRemoteRendezvous);
203 };
204 
205 }  // namespace
206 
GdrRendezvousMgr(const WorkerEnv * env,RemoteMemoryManager * remote_memory_manager)207 GdrRendezvousMgr::GdrRendezvousMgr(const WorkerEnv* env,
208                                    RemoteMemoryManager* remote_memory_manager)
209     : BaseRendezvousMgr(env), remote_memory_manager_(remote_memory_manager) {}
210 
Create(int64 step_id,const WorkerEnv * worker_env)211 BaseRemoteRendezvous* GdrRendezvousMgr::Create(int64 step_id,
212                                                const WorkerEnv* worker_env) {
213   return new GdrRemoteRendezvous(worker_env, step_id, remote_memory_manager_);
214 }
215 
216 }  // end namespace tensorflow
217