/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"

#include <utility>

#include "grpcpp/generic/generic_stub.h"
#include "grpcpp/grpcpp.h"

#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

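// Maximum number of times a failed worker RPC is retried by IssueRequest()
// before the error is reported to the caller.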
const int kMaxWorkerRpcRetries = 10;

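// WorkerInterface implementation that forwards each worker RPC over a shared
// gRPC channel to a remote worker process.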
class GrpcRemoteWorker : public WorkerInterface {
 public:
  explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
                            ::grpc::CompletionQueue* completion_queue,
                            thread::ThreadPool* callback_threadpool,
                            WorkerCacheLogger* logger)
      : channel_(std::move(channel)),
        stub_(channel_),
        cq_(completion_queue),
        callback_threadpool_(callback_threadpool),
        getstatus_(Method(GrpcWorkerMethod::kGetStatus)),
        createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)),
        deleteworkersession_(Method(GrpcWorkerMethod::kDeleteWorkerSession)),
        registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)),
        deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)),
        rungraph_(Method(GrpcWorkerMethod::kRunGraph)),
        cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)),
        cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)),
        recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)),
        recvbuf_(Method(GrpcWorkerMethod::kRecvBuf)),
        logging_(Method(GrpcWorkerMethod::kLogging)),
        tracing_(Method(GrpcWorkerMethod::kTracing)),
        completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
        instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
        getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
        logger_(logger) {}

  ~GrpcRemoteWorker() override {}

  void GetStatusAsync(const GetStatusRequest* request,
                      GetStatusResponse* response,
                      StatusCallback done) override {
    IssueRequest(request, response, getstatus_, std::move(done));
  }

  void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                CreateWorkerSessionResponse* response,
                                StatusCallback done) override {
    IssueRequest(request, response, createworkersession_, std::move(done));
  }

  void DeleteWorkerSessionAsync(CallOptions* call_opts,
                                const DeleteWorkerSessionRequest* request,
                                DeleteWorkerSessionResponse* response,
                                StatusCallback done) override {
    IssueRequest(request, response, deleteworkersession_, std::move(done),
                 call_opts);
  }

  void RegisterGraphAsync(const RegisterGraphRequest* request,
                          RegisterGraphResponse* response,
                          StatusCallback done) override {
    IssueRequest(request, response, registergraph_, std::move(done));
  }

  void DeregisterGraphAsync(const DeregisterGraphRequest* request,
                            DeregisterGraphResponse* response,
                            StatusCallback done) override {
    IssueRequest(request, response, deregistergraph_, std::move(done));
  }

  void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request,
                     RunGraphResponse* response, StatusCallback done) override {
    IssueRequest(request, response, rungraph_, std::move(done), call_opts);
  }
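  // Overload for the wrapper-based RunGraph API: the request wrapper is
  // converted to its proto form before the RPC is issued.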
  void RunGraphAsync(CallOptions* call_opts, RunGraphRequestWrapper* request,
                     MutableRunGraphResponseWrapper* response,
                     StatusCallback done) override {
    IssueRequest(&request->ToProto(), get_proto_from_wrapper(response),
                 rungraph_, std::move(done), call_opts);
  }

  void CleanupGraphAsync(const CleanupGraphRequest* request,
                         CleanupGraphResponse* response,
                         StatusCallback done) override {
    IssueRequest(request, response, cleanupgraph_, std::move(done));
  }

  void CleanupAllAsync(const CleanupAllRequest* request,
                       CleanupAllResponse* response,
                       StatusCallback done) override {
    IssueRequest(request, response, cleanupall_, std::move(done));
  }

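  // When transfer logging is active, the done callback is wrapped so that the
  // size and timing of the buffer transfer are recorded before the caller's
  // callback is invoked.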
  void RecvBufAsync(CallOptions* call_opts, const RecvBufRequest* request,
                    RecvBufResponse* response, StatusCallback done) override {
    int64 start_usec = Env::Default()->NowMicros();
    // Type-specialized logging for this method.
    bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
    StatusCallback wrapper_done;
    const StatusCallback* cb_to_use;
    if (!logging_active) {
      cb_to_use = &done;  // No additional work to do, so just use done directly
    } else {
      wrapper_done = [this, request, response, done, start_usec](Status s) {
        if (logger_->LoggingActive()) {
          int64 end_usec = Env::Default()->NowMicros();
          int64 step_id = request->step_id();
          RecvBufRespExtra extra;
          response->transport_options().UnpackTo(&extra);
          int64 num_bytes = 0;
          for (const auto& chunk : extra.tensor_content()) {
            num_bytes += chunk.size();
          }
          int64 send_start_usec = start_usec;
          // Prefer start time reported by the sender, if available.
          if (response->send_start_micros()) {
            send_start_usec = std::max(
                start_usec, static_cast<int64>(response->send_start_micros()));
            send_start_usec = std::min(send_start_usec, end_usec - 1);
          }
          const string& key = request->buf_rendezvous_key();
          logger_->RecordDataTransfer(
              step_id, send_start_usec, end_usec, key, request->src_device(),
              request->dst_device(), num_bytes, "", "RecvBuf");
        }
        VLOG(2) << "done callback, req: " << request->DebugString()
                << " response " << response->DebugString();
        done(s);
      };
      cb_to_use = &wrapper_done;
    }

    IssueRequest(request, response, recvbuf_, *cb_to_use, call_opts);
  }

  void CompleteGroupAsync(CallOptions* call_opts,
                          const CompleteGroupRequest* request,
                          CompleteGroupResponse* response,
                          StatusCallback done) override {
    IssueRequest(request, response, completegroup_, std::move(done), call_opts);
  }

  void CompleteInstanceAsync(CallOptions* call_opts,
                             const CompleteInstanceRequest* request,
                             CompleteInstanceResponse* response,
                             StatusCallback done) override {
    IssueRequest(request, response, instancesource_, std::move(done),
                 call_opts);
  }

  void GetStepSequenceAsync(const GetStepSequenceRequest* request,
                            GetStepSequenceResponse* response,
                            StatusCallback done) override {
    IssueRequest(request, response, getstepsequence_, std::move(done));
  }

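  // Like RecvBufAsync, wraps the done callback for logging when active; the
  // reply is decoded through the TensorResponse-specific IssueRequest overload
  // below.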
  void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request,
                       TensorResponse* response, StatusCallback done) override {
    VLOG(1) << "RecvTensorAsync req: " << request->DebugString();
    int64 start_usec = Env::Default()->NowMicros();
    // Type-specialized logging for this method.
    bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
    StatusCallback wrapper_done;
    const StatusCallback* cb_to_use;
    if (!logging_active) {
      cb_to_use = &done;  // No additional work to do, so just use done directly
    } else {
      wrapper_done = [this, request, response, done, start_usec](Status s) {
        if (logger_->LoggingActive()) {
          int64 end_usec = Env::Default()->NowMicros();
          int64 step_id = request->step_id();
          int64 bytes = response->tensor().TotalBytes();
          int64 send_start_usec = start_usec;
          // If a send start time was reported by the other side, use
          // that instead.  Maybe we should mark the display if we're using
          // our local time instead of the remote start time?
          if (response->metadata().send_start_micros()) {
            // send_start_micros is the timestamp taken when the
            // remote machine began to send the RecvTensor response.
            // Due to clock skew between source and dest machines, it
            // is possible that send_start_micros can be larger than
            // end_usec or less than start_usec.
            //
            // To respect causality, we enforce the invariants that
            // the RecvTensor response can not have been sent before
            // the RecvTensor request, and must have been sent before
            // it was received.
            send_start_usec = std::max(
                start_usec,
                static_cast<int64>(response->metadata().send_start_micros()));
            send_start_usec = std::min(send_start_usec, end_usec - 1);
          }
          const string& key = request->rendezvous_key();
          std::vector<string> key_parts = str_util::Split(key, ';');
          if (key_parts.size() != 5) {
            LOG(WARNING) << "Bad key: " << key;
          } else {
            logger_->RecordRecvTensor(step_id, send_start_usec, end_usec,
                                      key_parts[3],  // tensor name
                                      key_parts[0],  // src_device
                                      key_parts[2],  // dst_device
                                      bytes);
          }
        }
        VLOG(2) << "done callback, req: " << request->DebugString()
                << " response " << response->metadata().DebugString();
        done(s);
      };
      cb_to_use = &wrapper_done;
    }

    IssueRequest(request, response, recvtensor_, *cb_to_use, call_opts);
  }

  void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
                    StatusCallback done) override {
    IssueRequest(request, response, logging_, done);
  }

  void TracingAsync(const TracingRequest* request, TracingResponse* response,
                    StatusCallback done) override {
    IssueRequest(request, response, tracing_, done);
  }

 private:
  // Utility method for issuing a generic asynchronous request. The
  // given callback, `done`, will be called when the RPC completes.
  void IssueRequest(const protobuf::Message* request,
                    protobuf::Message* response, const ::grpc::string& method,
                    StatusCallback done, CallOptions* call_opts = nullptr,
                    int max_retries = kMaxWorkerRpcRetries) {
    new RPCState<protobuf::Message>(&stub_, cq_, method, *request, response,
                                    std::move(done), call_opts,
                                    callback_threadpool_, max_retries);
  }
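  // Overload used for RecvTensor: the reply is decoded directly into a
  // TensorResponse rather than a generic protobuf::Message.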
  void IssueRequest(const protobuf::Message* request, TensorResponse* response,
                    const ::grpc::string& method, StatusCallback done,
                    CallOptions* call_opts = nullptr,
                    int max_retries = kMaxWorkerRpcRetries) {
    new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
                                 std::move(done), call_opts,
                                 callback_threadpool_, max_retries);
  }

  // Helper function for initializing the RpcMethod objects below.
  const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); }

  SharedGrpcChannelPtr channel_;
  ::grpc::GenericStub stub_;
  ::grpc::CompletionQueue* cq_;
  thread::ThreadPool* callback_threadpool_;

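  // Cached gRPC method names, one per WorkerService RPC, initialized once in
  // the constructor.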
  const ::grpc::string getstatus_;
  const ::grpc::string createworkersession_;
  const ::grpc::string deleteworkersession_;
  const ::grpc::string registergraph_;
  const ::grpc::string deregistergraph_;
  const ::grpc::string rungraph_;
  const ::grpc::string cleanupgraph_;
  const ::grpc::string cleanupall_;
  const ::grpc::string recvtensor_;
  const ::grpc::string recvbuf_;
  const ::grpc::string logging_;
  const ::grpc::string tracing_;
  const ::grpc::string completegroup_;
  const ::grpc::string instancesource_;
  const ::grpc::string getstepsequence_;

  // Support for logging.
  WorkerCacheLogger* logger_;

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
};

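// Factory for creating a GrpcRemoteWorker; the concrete class is kept private
// to this translation unit and used only through WorkerInterface.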
WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
                                     ::grpc::CompletionQueue* completion_queue,
                                     thread::ThreadPool* callback_threadpool,
                                     WorkerCacheLogger* logger) {
  return new GrpcRemoteWorker(std::move(channel), completion_queue,
                              callback_threadpool, logger);
}

}  // namespace tensorflow