1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
17 
18 #include <utility>
19 
20 #include "tensorflow/core/distributed_runtime/call_options.h"
21 #include "tensorflow/core/distributed_runtime/master_interface.h"
22 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
23 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/strings/strcat.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/tracing.h"
29 #include "tensorflow/core/protobuf/master.pb.h"
30 
31 namespace tensorflow {
32 
33 // GrpcRemoteMaster is an implementation of the MasterInterface
34 // that uses gRPC to talk to the Master service.
35 class GrpcRemoteMaster : public MasterInterface {
36   using MasterServiceStub = grpc::MasterService::Stub;
37 
38  public:
GrpcRemoteMaster(const SharedGrpcChannelPtr & client_channel)39   explicit GrpcRemoteMaster(const SharedGrpcChannelPtr& client_channel)
40       : stub_(grpc::MasterService::NewStub(client_channel)) {}
41 
~GrpcRemoteMaster()42   ~GrpcRemoteMaster() override {}
43 
CreateSession(CallOptions * call_options,const CreateSessionRequest * request,CreateSessionResponse * response)44   Status CreateSession(CallOptions* call_options,
45                        const CreateSessionRequest* request,
46                        CreateSessionResponse* response) override {
47     return CallWithRetry(call_options, request, response,
48                          &MasterServiceStub::CreateSession);
49   }
50 
ExtendSession(CallOptions * call_options,const ExtendSessionRequest * request,ExtendSessionResponse * response)51   Status ExtendSession(CallOptions* call_options,
52                        const ExtendSessionRequest* request,
53                        ExtendSessionResponse* response) override {
54     return CallWithRetry(call_options, request, response,
55                          &MasterServiceStub::ExtendSession);
56   }
57 
PartialRunSetup(CallOptions * call_options,const PartialRunSetupRequest * request,PartialRunSetupResponse * response)58   Status PartialRunSetup(CallOptions* call_options,
59                          const PartialRunSetupRequest* request,
60                          PartialRunSetupResponse* response) override {
61     return CallWithRetry(call_options, request, response,
62                          &MasterServiceStub::PartialRunSetup);
63   }
64 
RunStep(CallOptions * call_options,RunStepRequestWrapper * request,MutableRunStepResponseWrapper * response)65   Status RunStep(CallOptions* call_options, RunStepRequestWrapper* request,
66                  MutableRunStepResponseWrapper* response) override {
67     return CallWithRetry(call_options, &request->ToProto(),
68                          get_proto_from_wrapper(response),
69                          &MasterServiceStub::RunStep, "RunStep/Client");
70   }
71 
CloseSession(CallOptions * call_options,const CloseSessionRequest * request,CloseSessionResponse * response)72   Status CloseSession(CallOptions* call_options,
73                       const CloseSessionRequest* request,
74                       CloseSessionResponse* response) override {
75     return CallWithRetry(call_options, request, response,
76                          &MasterServiceStub::CloseSession);
77   }
78 
ListDevices(CallOptions * call_options,const ListDevicesRequest * request,ListDevicesResponse * response)79   Status ListDevices(CallOptions* call_options,
80                      const ListDevicesRequest* request,
81                      ListDevicesResponse* response) override {
82     return CallWithRetry(call_options, request, response,
83                          &MasterServiceStub::ListDevices);
84   }
85 
Reset(CallOptions * call_options,const ResetRequest * request,ResetResponse * response)86   Status Reset(CallOptions* call_options, const ResetRequest* request,
87                ResetResponse* response) override {
88     return CallWithRetry(call_options, request, response,
89                          &MasterServiceStub::Reset);
90   }
91 
MakeCallable(CallOptions * call_options,const MakeCallableRequest * request,MakeCallableResponse * response)92   Status MakeCallable(CallOptions* call_options,
93                       const MakeCallableRequest* request,
94                       MakeCallableResponse* response) override {
95     return CallWithRetry(call_options, request, response,
96                          &MasterServiceStub::MakeCallable);
97   }
RunCallable(CallOptions * call_options,const RunCallableRequest * request,RunCallableResponse * response)98   Status RunCallable(CallOptions* call_options,
99                      const RunCallableRequest* request,
100                      RunCallableResponse* response) override {
101     return CallWithRetry(call_options, request, response,
102                          &MasterServiceStub::RunCallable);
103   }
ReleaseCallable(CallOptions * call_options,const ReleaseCallableRequest * request,ReleaseCallableResponse * response)104   Status ReleaseCallable(CallOptions* call_options,
105                          const ReleaseCallableRequest* request,
106                          ReleaseCallableResponse* response) override {
107     return CallWithRetry(call_options, request, response,
108                          &MasterServiceStub::ReleaseCallable);
109   }
110 
111  private:
112   // Start tracing, attaching a unique ID to both the trace and the RPC.
NewTraceRpc(StringPiece name,::grpc::ClientContext * ctx)113   tracing::ScopedActivity* NewTraceRpc(StringPiece name,
114                                        ::grpc::ClientContext* ctx) {
115     string trace_id = strings::StrCat(tracing::GetUniqueArg());
116     ctx->AddMetadata(GrpcIdKey(), trace_id);
117     return new tracing::ScopedActivity(name, trace_id);
118   }
119 
120   template <typename Request, typename Response>
CallWithRetry(CallOptions * call_options,const Request * request,Response * response,::grpc::Status (MasterServiceStub::* pfunc)(::grpc::ClientContext *,const Request &,Response *),string trace_string={})121   Status CallWithRetry(CallOptions* call_options, const Request* request,
122                        Response* response,
123                        ::grpc::Status (MasterServiceStub::*pfunc)(
124                            ::grpc::ClientContext*, const Request&, Response*),
125                        string trace_string = {}) {
126     int64 timeout_in_ms = call_options->GetTimeout();
127     int64 expired_time_micros = Env::Default()->NowMicros();
128     if (timeout_in_ms > 0) {
129       expired_time_micros += (timeout_in_ms / 1000.);
130     }
131     Status s;
132     for (int num_retries = 0;; ++num_retries) {
133       ::grpc::ClientContext ctx;
134       std::unique_ptr<tracing::ScopedActivity> trace;
135       if (!trace_string.empty()) {
136         trace.reset(NewTraceRpc(trace_string, &ctx));
137       }
138       ctx.set_fail_fast(false);
139       if (timeout_in_ms > 0) {
140         // We do not modify the timeout here to match legacy behavior. However,
141         // this could violate the contract of tensorflow::Session. If we retry
142         // an RPC just before the deadline is exceeded, we will still set the
143         // timeout to the original value. This leads to the overall timeout
144         // being double what was expected.
145         // TODO(b/117162170): investigate fixing this behavior for legacy and
146         // gRPC RPC layers.
147         ctx.set_deadline(gpr_time_from_millis(timeout_in_ms, GPR_TIMESPAN));
148       }
149       s = FromGrpcStatus((stub_.get()->*pfunc)(&ctx, *request, response));
150       if (!errors::IsUnavailable(s)) {
151         return s;
152       }
153       // TODO(b/117162170): we may want to make this configurable.
154       constexpr int kMaxRetries = 10;
155       LOG(WARNING) << "RPC failed with status = \"" << s
156                    << "\" and grpc_error_string = \""
157                    << ctx.debug_error_string() << "\", maybe retrying the RPC";
158       if (num_retries >= kMaxRetries) {
159         LOG(WARNING) << "Too many retries, returning last status: " << s;
160         return s;
161       }
162       const int64 now_micros = Env::Default()->NowMicros();
163       const int64 deadline_with_backoff_micros =
164           now_micros + ComputeBackoffMicroseconds(num_retries);
165       // Wait for a short period of time before retrying the RPC.  If our
166       // backoff would put us past the RPC deadline, we truncate it to ensure
167       // our RPC starts before the deadline.
168       const auto backoff_until =
169           (timeout_in_ms <= 0 ||
170            expired_time_micros > deadline_with_backoff_micros)
171               ? deadline_with_backoff_micros
172               : expired_time_micros;
173       Env::Default()->SleepForMicroseconds(backoff_until - now_micros);
174       if (Env::Default()->NowMicros() > expired_time_micros &&
175           timeout_in_ms > 0) {
176         // If timeout_in_ms is set, exit the retry loop on timeout.
177         return errors::DeadlineExceeded(ctx.debug_error_string());
178       }
179     }
180   }
181 
182   std::unique_ptr<MasterServiceStub> stub_;
183 };
184 
NewGrpcMaster(const SharedGrpcChannelPtr & channel)185 MasterInterface* NewGrpcMaster(const SharedGrpcChannelPtr& channel) {
186   return new GrpcRemoteMaster(channel);
187 }
188 
189 }  // namespace tensorflow
190