1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/contrib/gdr/gdr_rendezvous_mgr.h"
17
18 #include "google/protobuf/any.pb.h"
19 #include "tensorflow/contrib/gdr/gdr_memory_manager.h"
20 #include "tensorflow/core/common_runtime/device.h"
21 #include "tensorflow/core/common_runtime/device_mgr.h"
22 #include "tensorflow/core/common_runtime/process_util.h"
23 #include "tensorflow/core/distributed_runtime/request_id.h"
24 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
25 #include "tensorflow/core/distributed_runtime/worker_cache.h"
26 #include "tensorflow/core/distributed_runtime/worker_interface.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/strings/numbers.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/macros.h"
33 #include "tensorflow/core/platform/types.h"
34
35 namespace tensorflow {
36
37 namespace {
38
39 class GdrRecvTensorCall : public BaseRecvTensorCall {
40 public:
GdrRecvTensorCall(WorkerInterface * wi,Device * dst_device,RemoteMemoryManager * remote_memory_manager,const Rendezvous::Args & recv_args,int64 step_id,StringPiece key)41 GdrRecvTensorCall(WorkerInterface* wi, Device* dst_device,
42 RemoteMemoryManager* remote_memory_manager,
43 const Rendezvous::Args& recv_args, int64 step_id,
44 StringPiece key)
45 : wi_(wi),
46 dst_device_(dst_device),
47 remote_memory_manager_(remote_memory_manager),
48 recv_args_(recv_args) {
49 req_.set_step_id(step_id);
50 req_.set_rendezvous_key(key.data(), key.size());
51 req_.set_request_id(GetUniqueRequestId());
52 }
53
~GdrRecvTensorCall()54 ~GdrRecvTensorCall() override {}
55
Start(std::function<void ()> recv_done)56 void Start(std::function<void()> recv_done) override {
57 req_.set_dma_ok(true);
58 resp_.InitAlloc(dst_device_, recv_args_.alloc_attrs);
59 StatusCallback cb = [this, recv_done](const Status& s) {
60 bool dma_ok = resp_.metadata().has_transport_options();
61 if (s.ok() && tensor().TotalBytes() > 1024 && (!is_dead()) && dma_ok) {
62 auto transport_options = resp_.metadata().transport_options();
63 const bool on_host = recv_args_.alloc_attrs.on_host();
64 remote_memory_manager_->TensorFromTransportOptions(
65 const_cast<Tensor*>(&tensor()), transport_options, dst_device_,
66 recv_args_.device_context, on_host,
67 [this, recv_done](const Status& s) {
68 if (!s.ok()) {
69 mutex_lock l(mu_);
70 status_.Update(s);
71 }
72 recv_done();
73 });
74 return;
75 }
76 if (!s.ok()) {
77 mutex_lock l(mu_);
78 status_.Update(s);
79 }
80 recv_done();
81 };
82 wi_->RecvTensorAsync(&opts_, &req_, &resp_, std::move(cb));
83 }
84
StartAbort(const Status & s)85 void StartAbort(const Status& s) override {
86 {
87 mutex_lock l(mu_);
88 status_.Update(s);
89 }
90 opts_.StartCancel();
91 }
92
status() const93 Status status() const override {
94 mutex_lock l(mu_);
95 return status_;
96 }
97
tensor() const98 const Tensor& tensor() const { return resp_.tensor(); }
99
is_dead() const100 bool is_dead() const { return resp_.metadata().is_dead(); }
101
dst_device() const102 Device* dst_device() const { return dst_device_; }
103
recv_args() const104 const Rendezvous::Args& recv_args() const { return recv_args_; }
105
106 private:
107 WorkerInterface* wi_;
108 Device* dst_device_;
109 RemoteMemoryManager* remote_memory_manager_;
110 CallOptions opts_;
111 RecvTensorRequest req_;
112 TensorResponse resp_;
113 Rendezvous::Args recv_args_;
114
115 mutable mutex mu_;
116 Status status_ GUARDED_BY(mu_);
117
118 TF_DISALLOW_COPY_AND_ASSIGN(GdrRecvTensorCall);
119 };
120
121 class GdrRemoteRendezvous : public BaseRemoteRendezvous {
122 public:
GdrRemoteRendezvous(const WorkerEnv * env,int64 step_id,RemoteMemoryManager * remote_memory_manager)123 GdrRemoteRendezvous(const WorkerEnv* env, int64 step_id,
124 RemoteMemoryManager* remote_memory_manager)
125 : BaseRemoteRendezvous(env, step_id),
126 remote_memory_manager_(remote_memory_manager) {}
127
128 protected:
RecvFromRemoteAsync(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & recv_args,DoneCallback done)129 void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
130 const Rendezvous::Args& recv_args,
131 DoneCallback done) override {
132 CHECK(is_initialized());
133
134 string src_worker;
135 string src_rel_device;
136 if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_worker,
137 &src_rel_device)) {
138 Status s = errors::Internal(parsed.src_device,
139 " is invalid remote source device.");
140 done(s, Args(), recv_args, Tensor{}, false);
141 return;
142 }
143
144 WorkerSession* sess = session();
145 WorkerInterface* rwi = sess->worker_cache->CreateWorker(src_worker);
146 if (rwi == nullptr) {
147 Status s = errors::Internal("No worker known as ", src_worker);
148 done(s, Args(), recv_args, Tensor{}, false);
149 return;
150 }
151
152 Device* dst_device;
153 Status s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
154 if (!s.ok()) {
155 sess->worker_cache->ReleaseWorker(src_worker, rwi);
156 done(s, Args(), recv_args, Tensor{}, false);
157 return;
158 }
159
160 // Prepare a RecvTensor call that can handle being aborted.
161 GdrRecvTensorCall* call =
162 new GdrRecvTensorCall(rwi, dst_device, remote_memory_manager_,
163 recv_args, step_id_, parsed.FullKey());
164
165 // Record "call" in active_ so that it can be aborted cleanly.
166 RegisterCall(call);
167
168 // RendezvousMgr already aborted, shouldn't send RPC call any more
169 if (!call->status().ok()) {
170 // NOTE: `*session()` can potentially be deleted before we return from
171 // `call->done()(...)`, so we must release the worker before calling the
172 // callback.
173 session()->worker_cache->ReleaseWorker(src_worker, rwi);
174 done(call->status(), Args(), Args(), Tensor(), false);
175 delete call;
176 return;
177 }
178
179 // Start "call".
180 Ref();
181 call->Start([this, call, src_worker, rwi, done]() {
182 // Removes "call" from active_. Prevent StartAbort().
183 DeregisterCall(call);
184 // If StartAbort was called prior to DeregisterCall, then the
185 // current status should be bad.
186 Status s = call->status();
187 // NOTE: `*session()` can potentially be deleted before we return from
188 // `call->done()(...)`, so we must release the worker before calling the
189 // callback.
190 session()->worker_cache->ReleaseWorker(src_worker, rwi);
191 done(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
192 delete call;
193 Unref();
194 });
195 }
196
197 private:
~GdrRemoteRendezvous()198 ~GdrRemoteRendezvous() override {}
199
200 RemoteMemoryManager* remote_memory_manager_;
201
202 TF_DISALLOW_COPY_AND_ASSIGN(GdrRemoteRendezvous);
203 };
204
205 } // namespace
206
GdrRendezvousMgr(const WorkerEnv * env,RemoteMemoryManager * remote_memory_manager)207 GdrRendezvousMgr::GdrRendezvousMgr(const WorkerEnv* env,
208 RemoteMemoryManager* remote_memory_manager)
209 : BaseRendezvousMgr(env), remote_memory_manager_(remote_memory_manager) {}
210
Create(int64 step_id,const WorkerEnv * worker_env)211 BaseRemoteRendezvous* GdrRendezvousMgr::Create(int64 step_id,
212 const WorkerEnv* worker_env) {
213 return new GdrRemoteRendezvous(worker_env, step_id, remote_memory_manager_);
214 }
215
216 } // end namespace tensorflow
217