1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
17
18 #include <utility>
19
20 #include "tensorflow/core/distributed_runtime/call_options.h"
21 #include "tensorflow/core/distributed_runtime/master_interface.h"
22 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
23 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/strings/strcat.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/tracing.h"
29 #include "tensorflow/core/protobuf/master.pb.h"
30
31 namespace tensorflow {
32
33 // GrpcRemoteMaster is an implementation of the MasterInterface
34 // that uses gRPC to talk to the Master service.
35 class GrpcRemoteMaster : public MasterInterface {
36 using MasterServiceStub = grpc::MasterService::Stub;
37
38 public:
GrpcRemoteMaster(const SharedGrpcChannelPtr & client_channel)39 explicit GrpcRemoteMaster(const SharedGrpcChannelPtr& client_channel)
40 : stub_(grpc::MasterService::NewStub(client_channel)) {}
41
~GrpcRemoteMaster()42 ~GrpcRemoteMaster() override {}
43
CreateSession(CallOptions * call_options,const CreateSessionRequest * request,CreateSessionResponse * response)44 Status CreateSession(CallOptions* call_options,
45 const CreateSessionRequest* request,
46 CreateSessionResponse* response) override {
47 return CallWithRetry(call_options, request, response,
48 &MasterServiceStub::CreateSession);
49 }
50
ExtendSession(CallOptions * call_options,const ExtendSessionRequest * request,ExtendSessionResponse * response)51 Status ExtendSession(CallOptions* call_options,
52 const ExtendSessionRequest* request,
53 ExtendSessionResponse* response) override {
54 return CallWithRetry(call_options, request, response,
55 &MasterServiceStub::ExtendSession);
56 }
57
PartialRunSetup(CallOptions * call_options,const PartialRunSetupRequest * request,PartialRunSetupResponse * response)58 Status PartialRunSetup(CallOptions* call_options,
59 const PartialRunSetupRequest* request,
60 PartialRunSetupResponse* response) override {
61 return CallWithRetry(call_options, request, response,
62 &MasterServiceStub::PartialRunSetup);
63 }
64
RunStep(CallOptions * call_options,RunStepRequestWrapper * request,MutableRunStepResponseWrapper * response)65 Status RunStep(CallOptions* call_options, RunStepRequestWrapper* request,
66 MutableRunStepResponseWrapper* response) override {
67 return CallWithRetry(call_options, &request->ToProto(),
68 get_proto_from_wrapper(response),
69 &MasterServiceStub::RunStep, "RunStep/Client");
70 }
71
CloseSession(CallOptions * call_options,const CloseSessionRequest * request,CloseSessionResponse * response)72 Status CloseSession(CallOptions* call_options,
73 const CloseSessionRequest* request,
74 CloseSessionResponse* response) override {
75 return CallWithRetry(call_options, request, response,
76 &MasterServiceStub::CloseSession);
77 }
78
ListDevices(CallOptions * call_options,const ListDevicesRequest * request,ListDevicesResponse * response)79 Status ListDevices(CallOptions* call_options,
80 const ListDevicesRequest* request,
81 ListDevicesResponse* response) override {
82 return CallWithRetry(call_options, request, response,
83 &MasterServiceStub::ListDevices);
84 }
85
Reset(CallOptions * call_options,const ResetRequest * request,ResetResponse * response)86 Status Reset(CallOptions* call_options, const ResetRequest* request,
87 ResetResponse* response) override {
88 return CallWithRetry(call_options, request, response,
89 &MasterServiceStub::Reset);
90 }
91
MakeCallable(CallOptions * call_options,const MakeCallableRequest * request,MakeCallableResponse * response)92 Status MakeCallable(CallOptions* call_options,
93 const MakeCallableRequest* request,
94 MakeCallableResponse* response) override {
95 return CallWithRetry(call_options, request, response,
96 &MasterServiceStub::MakeCallable);
97 }
RunCallable(CallOptions * call_options,const RunCallableRequest * request,RunCallableResponse * response)98 Status RunCallable(CallOptions* call_options,
99 const RunCallableRequest* request,
100 RunCallableResponse* response) override {
101 return CallWithRetry(call_options, request, response,
102 &MasterServiceStub::RunCallable);
103 }
ReleaseCallable(CallOptions * call_options,const ReleaseCallableRequest * request,ReleaseCallableResponse * response)104 Status ReleaseCallable(CallOptions* call_options,
105 const ReleaseCallableRequest* request,
106 ReleaseCallableResponse* response) override {
107 return CallWithRetry(call_options, request, response,
108 &MasterServiceStub::ReleaseCallable);
109 }
110
111 private:
112 // Start tracing, attaching a unique ID to both the trace and the RPC.
NewTraceRpc(StringPiece name,::grpc::ClientContext * ctx)113 tracing::ScopedActivity* NewTraceRpc(StringPiece name,
114 ::grpc::ClientContext* ctx) {
115 string trace_id = strings::StrCat(tracing::GetUniqueArg());
116 ctx->AddMetadata(GrpcIdKey(), trace_id);
117 return new tracing::ScopedActivity(name, trace_id);
118 }
119
120 template <typename Request, typename Response>
CallWithRetry(CallOptions * call_options,const Request * request,Response * response,::grpc::Status (MasterServiceStub::* pfunc)(::grpc::ClientContext *,const Request &,Response *),string trace_string={})121 Status CallWithRetry(CallOptions* call_options, const Request* request,
122 Response* response,
123 ::grpc::Status (MasterServiceStub::*pfunc)(
124 ::grpc::ClientContext*, const Request&, Response*),
125 string trace_string = {}) {
126 int64 timeout_in_ms = call_options->GetTimeout();
127 int64 expired_time_micros = Env::Default()->NowMicros();
128 if (timeout_in_ms > 0) {
129 expired_time_micros += (timeout_in_ms / 1000.);
130 }
131 Status s;
132 for (int num_retries = 0;; ++num_retries) {
133 ::grpc::ClientContext ctx;
134 std::unique_ptr<tracing::ScopedActivity> trace;
135 if (!trace_string.empty()) {
136 trace.reset(NewTraceRpc(trace_string, &ctx));
137 }
138 ctx.set_fail_fast(false);
139 if (timeout_in_ms > 0) {
140 // We do not modify the timeout here to match legacy behavior. However,
141 // this could violate the contract of tensorflow::Session. If we retry
142 // an RPC just before the deadline is exceeded, we will still set the
143 // timeout to the original value. This leads to the overall timeout
144 // being double what was expected.
145 // TODO(b/117162170): investigate fixing this behavior for legacy and
146 // gRPC RPC layers.
147 ctx.set_deadline(gpr_time_from_millis(timeout_in_ms, GPR_TIMESPAN));
148 }
149 s = FromGrpcStatus((stub_.get()->*pfunc)(&ctx, *request, response));
150 if (!errors::IsUnavailable(s)) {
151 return s;
152 }
153 // TODO(b/117162170): we may want to make this configurable.
154 constexpr int kMaxRetries = 10;
155 LOG(WARNING) << "RPC failed with status = \"" << s
156 << "\" and grpc_error_string = \""
157 << ctx.debug_error_string() << "\", maybe retrying the RPC";
158 if (num_retries >= kMaxRetries) {
159 LOG(WARNING) << "Too many retries, returning last status: " << s;
160 return s;
161 }
162 const int64 now_micros = Env::Default()->NowMicros();
163 const int64 deadline_with_backoff_micros =
164 now_micros + ComputeBackoffMicroseconds(num_retries);
165 // Wait for a short period of time before retrying the RPC. If our
166 // backoff would put us past the RPC deadline, we truncate it to ensure
167 // our RPC starts before the deadline.
168 const auto backoff_until =
169 (timeout_in_ms <= 0 ||
170 expired_time_micros > deadline_with_backoff_micros)
171 ? deadline_with_backoff_micros
172 : expired_time_micros;
173 Env::Default()->SleepForMicroseconds(backoff_until - now_micros);
174 if (Env::Default()->NowMicros() > expired_time_micros &&
175 timeout_in_ms > 0) {
176 // If timeout_in_ms is set, exit the retry loop on timeout.
177 return errors::DeadlineExceeded(ctx.debug_error_string());
178 }
179 }
180 }
181
182 std::unique_ptr<MasterServiceStub> stub_;
183 };
184
NewGrpcMaster(const SharedGrpcChannelPtr & channel)185 MasterInterface* NewGrpcMaster(const SharedGrpcChannelPtr& channel) {
186 return new GrpcRemoteMaster(channel);
187 }
188
189 } // namespace tensorflow
190