1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
18 
19 #include <queue>
20 #include <utility>
21 
22 #include "grpcpp/generic/generic_stub.h"
23 #include "grpcpp/grpcpp.h"
24 #include "tensorflow/core/distributed_runtime/call_options.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
27 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
28 #include "tensorflow/core/lib/core/refcount.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/core/threadpool.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/notification.h"
34 #include "tensorflow/core/util/env_var.h"
35 
36 namespace tensorflow {
37 
38 // Object allocated per active RPC.
39 // Manage the state of a single asynchronous RPC request.  If `max_retries`
40 // is greater than 0, the request will be retried for any transient failures.
41 template <class Response>
42 class RPCState : public GrpcClientCQTag {
43  public:
44   RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
45            const ::grpc::string& method, const protobuf::Message& request,
46            Response* response, StatusCallback done, CallOptions* call_opts,
47            thread::ThreadPool* threadpool, int32 max_retries = 0,
48            bool fail_fast = true, const string* target = nullptr)
49       : RPCState(
50             stub, cq, method, request, response, std::move(done), call_opts,
51             threadpool,
52             // 1) If GRPC_FAIL_FAST is set to 'true' or 'false',
53             // fail_fast=$GRPC_FAIL_FAST. See b/141948186.
54             // 2) Otherwise if GRPC_FAIL_FAST is set to 'use_caller', use the
55             // fail_fast from the caller. See b/140260119.
56             //
57             // Current default for PLATFORM_GOOGLE: use caller fail_fast;
58             // Current default for open source: fail_fast=false.
59             //
60             // NOTE: Callers mostly set fail_fast=true to prevent job hanging
61             // on worker task failures, except a few cases such as GetStatus
62             // in cluster initialization and collective param resolution.
63             [fail_fast, &done]() -> bool {
64               string fail_fast_env;
65 #if defined(PLATFORM_GOOGLE)
66               TF_CHECK_OK(ReadStringFromEnvVar("GRPC_FAIL_FAST", "use_caller",
67                                                &fail_fast_env));
68 #else
69               TF_CHECK_OK(ReadStringFromEnvVar("GRPC_FAIL_FAST", "false",
70                                                &fail_fast_env));
71 #endif  // PLATFORM_GOOGLE
72               string fail_fast_env_lower = absl::AsciiStrToLower(fail_fast_env);
73               if (fail_fast_env_lower == "true") {
74                 return true;
75               } else if (fail_fast_env_lower == "use_caller") {
76                 return fail_fast;
77               } else if (fail_fast_env_lower == "false") {
78                 return false;
79               } else {
80                 string error_message = strings::StrCat(
81                     "Invalid GRPC_FAIL_FAST config: ", fail_fast_env);
82                 LOG(WARNING) << error_message;
83                 done(errors::InvalidArgument(error_message));
84                 return false;
85               }
86             }(),
87             (call_opts != nullptr ? call_opts->GetTimeout() : 0), max_retries,
88             target) {
89   }
90 
91   template <typename Request>
RPCState(::grpc::GenericStub * stub,::grpc::CompletionQueue * cq,const::grpc::string & method,const Request & request,Response * response,StatusCallback done,CallOptions * call_opts,thread::ThreadPool * threadpool,bool fail_fast,int64 timeout_in_ms,int32 max_retries,const string * target)92   RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
93            const ::grpc::string& method, const Request& request,
94            Response* response, StatusCallback done, CallOptions* call_opts,
95            thread::ThreadPool* threadpool, bool fail_fast, int64 timeout_in_ms,
96            int32 max_retries, const string* target)
97       : call_opts_(call_opts),
98         threadpool_(threadpool),
99         done_(std::move(done)),
100         timeout_in_ms_(timeout_in_ms),
101         max_retries_(max_retries),
102         cq_(cq),
103         stub_(stub),
104         method_(method),
105         fail_fast_(fail_fast),
106         target_(target) {
107     response_ = response;
108     ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf_);
109     if (!s.ok()) {
110       LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: "
111                  << s.error_message();
112       // Skip retry logic if we fail to parse our request.
113       done_(FromGrpcStatus(s));
114       delete this;
115       return;
116     }
117     StartCall();
118   }
119 
StartCall()120   void StartCall() {
121     context_.reset(new ::grpc::ClientContext());
122     context_->set_wait_for_ready(!fail_fast_);
123     if (timeout_in_ms_ > 0) {
124       context_->set_deadline(
125           gpr_time_from_millis(timeout_in_ms_, GPR_TIMESPAN));
126     }
127     if (call_opts_) {
128       call_opts_->SetCancelCallback([this]() { context_->TryCancel(); });
129     }
130 
131     VLOG(2) << "Starting call: " << method_;
132 
133     call_ = stub_->PrepareUnaryCall(context_.get(), method_, request_buf_, cq_);
134     call_->StartCall();
135     call_->Finish(&response_buf_, &status_, this);
136   }
137 
OnCompleted(bool ok)138   void OnCompleted(bool ok) override {
139     if (call_opts_) {
140       call_opts_->ClearCancelCallback();
141     }
142 
143     VLOG(2) << "Completed call: " << method_;
144 
145     Status s = FromGrpcStatus(status_);
146     if (s.ok() && !ok) {
147       // Since this function is only being used for processing the response
148       // to Finish for client-side unary calls, ok should never be false
149       s.Update(
150           errors::Internal("GRPC status is okay but CompletionQueueStatus is "
151                            "not.  This should never happen."));
152     }
153 
154     if (s.ok()) {
155       if (threadpool_) {
156         // Run parse and callback in another thread, returning this
157         // one to service more RPCs.
158         threadpool_->Schedule([this]() { ParseAndCallDone(); });
159       } else {
160         ParseAndCallDone();
161       }
162       return;
163     }
164 
165     VLOG(1) << method_ << " returned with non-ok status: " << s
166             << " Retries: " << num_retries_ << " Max: " << max_retries_ << "\n"
167             << context_->debug_error_string();
168     // Retry if we have any attempts left
169     if (++num_retries_ <= max_retries_ &&
170         (errors::IsUnavailable(s) || errors::IsUnknown(s))) {
171       response_buf_.Clear();
172       VLOG(1) << "Retrying call for " << method_ << "Retry: " << num_retries_
173               << " of " << max_retries_;
174       // TODO(b/139945426) Allow user to configure the retry backoff time.
175       StartCall();
176     } else {
177       // Attach additional GRPC error information if any to the final status
178       string error_msg = s.error_message();
179       strings::StrAppend(&error_msg, "\nAdditional GRPC error information");
180       if (target_) {
181         strings::StrAppend(&error_msg, " from remote target ", *target_);
182       }
183       strings::StrAppend(&error_msg, ":\n:", context_->debug_error_string());
184       s = Status(s.code(), error_msg);
185       // Always treat gRPC cancellation as a derived error. This ensures that
186       // other error types are preferred during status aggregation. (gRPC
187       // cancellation messages do not contain the original status message).
188       if (s.code() == tensorflow::error::Code::CANCELLED) {
189         s = StatusGroup::MakeDerived(s);
190       }
191 
192       done_(s);
193       delete this;
194     }
195   }
196 
ParseAndCallDone()197   void ParseAndCallDone() {
198     Status s;
199     if (!GrpcMaybeParseProto(&response_buf_, response_)) {
200       s.Update(errors::Internal("could not parse rpc response"));
201     }
202     done_(s);
203     delete this;
204   }
205 
206  private:
207   CallOptions* call_opts_;
208   std::unique_ptr<::grpc::ClientContext> context_;
209   thread::ThreadPool* threadpool_;
210   std::unique_ptr<::grpc::GenericClientAsyncResponseReader> call_;
211   Response* response_;
212   ::grpc::ByteBuffer request_buf_;
213   ::grpc::ByteBuffer response_buf_;
214   ::grpc::Status status_;
215   StatusCallback done_;
216   int64 timeout_in_ms_;
217 
218   size_t num_retries_ = 0;
219   size_t max_retries_;
220 
221   ::grpc::CompletionQueue* cq_;
222   ::grpc::GenericStub* stub_;
223   ::grpc::string method_;
224   bool fail_fast_;
225   const string* target_;
226 };
227 
// Represents state associated with one streaming RPC call.
// The parts of StreamingRPCState that don't need to be templated are
// extracted into this abstract class.
// Currently, *StreamingRPCState does not support client closing the call as
// there is no use case for it - current clients keep the streaming call open
// as long as possible. If/when the need arises, support can be added
// by calling GenericClientAsyncReaderWriter::WritesDone with a new tag
// TagType::kClientFinished and handling the completion in a new callback.
class UntypedStreamingRPCState : public core::RefCounted {
 public:
  // Completion callbacks, one per Tag::TagType. Each is dispatched from
  // Tag::OnCompleted(); `ok` is the value that OnCompleted() received.
  virtual void CallStarted(bool ok) = 0;
  virtual void RequestWriteCompleted(bool ok) = 0;
  virtual void ResponseReadCompleted(bool ok) = 0;
  virtual void CallFinished(bool ok) = 0;

  // Returns a human-readable description of the call state.
  virtual string DebugString() const = 0;

  // Completion-queue tag that routes a completion event to the matching
  // callback of the owning UntypedStreamingRPCState.
  class Tag : public GrpcClientCQTag {
   public:
    // One enum value per supported callback.
    enum class TagType {
      kCallStarted,
      kRequestWriteCompleted,
      kResponseReadCompleted,
      kCallFinished,
    };

    // Binds the tag to `streaming_state` (not owned) and to the callback
    // selected by `type`.
    Tag(UntypedStreamingRPCState* streaming_state, Tag::TagType type);

    // Calls the callback associated with this tag and Unrefs
    // `this->streaming_state_`.
    void OnCompleted(bool ok) override;

   private:
    // OnCompleted() consumes one reference each time it is called, so
    // whoever hands this tag to gRPC must take a reference first.
    UntypedStreamingRPCState* const streaming_state_;
    const Tag::TagType type_;
  };
};
267 
// Returns a C string for `tag_type`. Defined outside this header; presumably
// the enumerator's name for use in log messages - confirm against the
// definition.
const char* ToString(UntypedStreamingRPCState::Tag::TagType tag_type);
269 
270 // Represents a single request/response exchange between client and the server.
271 // A single streaming call contains a sequence of exchanges. Besides the
272 // messages, exchange contains:
273 //  - the user callback to invoke when exchange completes (response is received
274 //    or an error occurs).
275 //  - The current state of the exchange.
276 class Exchange {
277  public:
278   enum class State {
279     kExchangeCreated,
280     kRequestWriteIssued,
281     kRequestWriteCompleted,
282     kResponseReadIssued,
283   };
284 
Exchange(const::grpc::ByteBuffer & request_buf,protobuf::Message * response,StatusCallback cb,string debug_string)285   Exchange(const ::grpc::ByteBuffer& request_buf, protobuf::Message* response,
286            StatusCallback cb, string debug_string)
287       : state_(State::kExchangeCreated),
288         request_buf_(request_buf),
289         response_(response),
290         cb_(std::move(cb)),
291         debug_string_(std::move(debug_string)) {}
292 
request_buf()293   const ::grpc::ByteBuffer& request_buf() { return request_buf_; }
response_buf()294   ::grpc::ByteBuffer* response_buf() { return &response_buf_; }
295 
MarkRequestWriteIssued()296   void MarkRequestWriteIssued() {
297     DCHECK(state_ == State::kExchangeCreated);
298     state_ = State::kRequestWriteIssued;
299   }
MarkRequestWriteCompleted()300   void MarkRequestWriteCompleted() {
301     DCHECK(state_ == State::kRequestWriteIssued);
302     state_ = State::kRequestWriteCompleted;
303   }
MarkResponseReadIssued()304   void MarkResponseReadIssued() {
305     DCHECK(state_ == State::kRequestWriteCompleted);
306     state_ = State::kResponseReadIssued;
307   }
308 
309   // If `status` is success, completes this exchange by parsing the
310   // response_buf_ and invoking cb_ with Status::OK(). Else, invokes the
311   // callback with `status`.
312   void Complete(Status status);
313 
state()314   const State& state() const { return state_; }
315 
316   string DebugString() const;
317 
318  private:
319   State state_;
320   ::grpc::ByteBuffer request_buf_;
321   ::grpc::ByteBuffer response_buf_;
322   protobuf::Message* response_;
323   StatusCallback cb_;
324   string debug_string_;
325 };
326 
// Returns a C string for `s`. Defined outside this header; presumably the
// enumerator's name - confirm against the definition.
const char* ToString(Exchange::State s);

// Streams `state` into `os` and returns `os`. Defined outside this header;
// presumably prints the same text as ToString(state) - confirm.
std::ostream& operator<<(std::ostream& os, const Exchange::State& state);
330 
// Represents a queue of exchanges.
// When a client sends a new request a new exchange is created and added to the
// end of the queue. Completed exchanges are popped from the front of the queue.
// An explicit exchange queue is needed to bridge the client, which can send new
// requests at any time, with gRPC infrastructure, which can handle a single
// read and a single write request at a time.
//
// As the exchange progresses (request sending initiated, request sending
// completed, response reading initiated) the queue helps to make sure that the
// right operation is issued on the right exchange at the right time.
//
// To satisfy gRPC constraints, the states of exchanges must be as follows
// starting from the front of the queue:
//  - 0 or 1 exchange in kResponseReadIssued state
//  - 0 or more exchanges in kRequestWriteCompleted state
//  - 0 or 1 exchange in kRequestWriteIssued state
//  - 0 or more exchanges in kExchangeCreated state
//
// Thread-compatible.
class ExchangeQueue {
 public:
  // Creates a new exchange and adds it to the end of the queue.
  void Emplace(const ::grpc::ByteBuffer& request_buf,
               protobuf::Message* response, StatusCallback cb,
               std::string debug_string);

  // Returns an exchange for which we can initiate request writing, if any.
  // Returns nullptr if there is no such exchange.
  Exchange* GetReadyForRequestWriting();

  // Returns an exchange for which we can initiate response reading, if any.
  // Returns nullptr if there is no such exchange.
  Exchange* GetReadyForResponseReading();

  // Changes the state of the exchange that is currently in the
  // kRequestWriteIssued state to the kRequestWriteCompleted state.
  // REQUIRES: There is an exchange in kRequestWriteIssued state.
  void MarkRequestWriteCompleted();

  // Returns the exchange at the front of the queue.
  // REQUIRES: ExchangeQueue is not empty.
  Exchange& GetFront();

  // Removes the exchange at the front of the queue.
  // REQUIRES: ExchangeQueue is not empty.
  void PopFront();

  // Returns a string containing addresses and states of all exchanges in this
  // queue.
  string DebugString() const;

  // Swaps the contents of this and `other`.
  void Swap(ExchangeQueue* other);

  // Completes all exchanges in this with `status`.
  void CompleteAll(Status status);

  // Records that the underlying call has started (sets call_started_).
  void CallStarted() { call_started_ = true; }

 private:
  // Does nothing by default. Turn on VLOG(5) to enable.
  // Checks that this ExchangeQueue is in a valid state.
  // Kills the process if not.
  void CheckInvariants();

  // We can't process any exchanges until the call has started.
  bool call_started_ = false;

  // A std::deque is used because pushing or popping at either end leaves
  // references and pointers to the remaining elements valid, so Exchange
  // pointers handed out by this class stay usable while the queue changes.
  // NOTE(review): this header includes <queue> rather than <deque>;
  // std::deque is presumably reached transitively - consider including
  // <deque> directly.
  std::deque<Exchange> exchanges_;
};
403 
404 // Represents state associated with one streaming RPC call.
405 // Thread-safe
406 template <class Response>
407 class StreamingRPCState : public UntypedStreamingRPCState {
408  public:
409   // Default behavior is to set fail_fast = False and handle timeouts
410   // manually.
StreamingRPCState(std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call,const std::shared_ptr<::grpc::ClientContext> & context)411   StreamingRPCState(std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call,
412                     const std::shared_ptr<::grpc::ClientContext>& context)
413       : context_(context), call_(std::move(call)), call_state_(State::kActive) {
414     Ref();
415     VLOG(3) << "Created new StreamingRPCState " << this;
416     VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::StartCall";
417     call_->StartCall(&call_started_tag_);
418   }
419 
~StreamingRPCState()420   ~StreamingRPCState() override {
421     VLOG(3) << "Destructing StreamingRPCState " << this;
422   }
423 
424   // Attempts to send the next request. `done` is invoked when
425   // `response` has been filled with the data from the server, or if there
426   // is an error. `done` can be invoked before SendNextRequest returns.
427   // Return `true` if the call is alive and the `done` callback has or
428   // will be invoked. If the call is dead, returns `false`. `done` callback
429   // will not be invoked in this case.
430   // REQUIRES: The call has been started, i.e. WaitForCallStarted() has
431   // returned.
SendNextRequest(const protobuf::Message & request,Response * response,const StatusCallback & done)432   bool SendNextRequest(const protobuf::Message& request, Response* response,
433                        const StatusCallback& done) {
434     ::grpc::ByteBuffer request_buf;
435     ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf);
436     if (!s.ok()) {
437       Status status = FromGrpcStatus(s);
438       LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: "
439                  << status.ToString();
440       done(status);
441       return true;
442     }
443 
444     mutex_lock l(mu_);
445     if (call_state_ != State::kActive) {
446       // `done` is not invoked intentionally.
447       return false;
448     }
449     if (VLOG_IS_ON(3)) {
450       // If vlog 3 is enabled, include first 100 chars of request as debug
451       // string.
452       exchanges_.Emplace(request_buf, response, done,
453                          request.ShortDebugString().substr(0, 100));
454     } else {
455       exchanges_.Emplace(request_buf, response, done, "");
456     }
457     MaybeIssueRequestWriteLocked();
458     return true;
459   }
460 
CallStarted(bool ok)461   void CallStarted(bool ok) override {
462     VLOG(3) << "StreamingRPCState(" << this << ")::CallStarted(ok=" << ok
463             << ")";
464     mutex_lock l(mu_);
465     if (!ok) {
466       call_state_ = State::kDone;
467       return;
468     }
469     exchanges_.CallStarted();
470     // Now that the call has started, we can write our first request, if any.
471     MaybeIssueRequestWriteLocked();
472   }
473 
RequestWriteCompleted(bool ok)474   void RequestWriteCompleted(bool ok) override {
475     VLOG(3) << "StreamingRPCState(" << this
476             << ")::RequestWriteCompleted(ok=" << ok << ")";
477     mu_.lock();
478     if (call_state_ != State::kActive) {
479       mu_.unlock();
480       return;
481     }
482     exchanges_.MarkRequestWriteCompleted();
483     // Issue ResponseRead regardless of OK status on completing RequestWrite.
484     // If the underlying completion queue is in Not-OK status due to previous
485     // request failuress (i.e., `ok` from `Next` call on completion queue is
486     // False), delay the error in ResponseRead so we can get the remote error
487     // message from response buffer.
488     MaybeIssueResponseReadLocked();
489 
490     if (ok) {
491       MaybeIssueRequestWriteLocked();
492     }
493     mu_.unlock();
494   }
495 
ResponseReadCompleted(bool ok)496   void ResponseReadCompleted(bool ok) override {
497     VLOG(3) << "StreamingRPCState(" << this
498             << ")::ResponseReadCompleted(ok=" << ok << ")";
499     mu_.lock();
500     if (call_state_ != State::kActive) {
501       mu_.unlock();
502       return;
503     }
504     if (!ok) {
505       IssueCallFinishLocked();
506       mu_.unlock();
507       return;
508     }
509 
510     // Complete the exchange without holding the lock because user's
511     // callback can call back into this RPC code resulting in a deadlock.
512     // No other thread can pop this exchange while we release the lock because
513     // this is the only method that pops exchanges and it is called from a
514     // single thread that waits on completion queue events.
515     Exchange* e;
516     e = &exchanges_.GetFront();
517     mu_.unlock();
518 
519     e->Complete(Status::OK());
520 
521     {
522       mutex_lock l(mu_);
523       exchanges_.PopFront();
524       MaybeIssueResponseReadLocked();
525     }
526   }
527 
CallFinished(bool ok)528   void CallFinished(bool ok) override {
529     VLOG(3) << "StreamingRPCState(" << this << ")::CallFinished(ok=" << ok
530             << ")";
531     mu_.lock();
532     DCHECK(call_state_ != State::kActive);
533     if (call_state_ != State::kFinishing) {
534       mu_.unlock();
535       return;
536     }
537 
538     Status s = FromGrpcStatus(call_status_);
539     if (s.ok() && !ok) {
540       s.Update(
541           errors::Internal("GRPC status is okay but CompletionQueueStatus is "
542                            "not.  This should never happen.",
543                            context_->debug_error_string()));
544     }
545     // unlocks mu_
546     MarkDoneAndCompleteExchanges(s);
547   }
548 
DebugString()549   string DebugString() const override {
550     mutex_lock l(mu_);
551     return exchanges_.DebugString();
552   }
553 
554  private:
555   enum class State {
556     kActive,
557     kFinishing,
558     kDone,
559   };
560 
MarkDoneAndCompleteExchanges(Status status)561   void MarkDoneAndCompleteExchanges(Status status)
562       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_UNLOCK_FUNCTION(mu_) {
563     call_state_ = State::kDone;
564     VLOG(2) << "Ending gRPC streaming call on the client side due to "
565             << status.ToString();
566     // Swap the exchanges_ into a temporary ExchangeQueue so that we can
567     // complete all exchanges without holding mu_ in case user callback
568     // reach back into this. This should be impossible now, but safer for
569     // the future.
570     ExchangeQueue queue;
571     exchanges_.Swap(&queue);
572     mu_.unlock();
573     queue.CompleteAll(status);
574   }
575 
MaybeIssueRequestWriteLocked()576   void MaybeIssueRequestWriteLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
577     Exchange* exchange = exchanges_.GetReadyForRequestWriting();
578     if (exchange == nullptr) {
579       // There are no queued exchanges, there is already an outstanding write,
580       // or there are no just created exchanges.
581       return;
582     }
583     exchange->MarkRequestWriteIssued();
584     Ref();
585     VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Write";
586     call_->Write(exchange->request_buf(), &request_write_completed_tag_);
587   }
588 
MaybeIssueResponseReadLocked()589   void MaybeIssueResponseReadLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
590     Exchange* exchange = exchanges_.GetReadyForResponseReading();
591     if (exchange == nullptr) {
592       return;
593     }
594     exchange->MarkResponseReadIssued();
595     Ref();
596     VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Read";
597     call_->Read(exchange->response_buf(), &response_read_completed_tag_);
598   }
599 
IssueCallFinishLocked()600   void IssueCallFinishLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
601     call_state_ = State::kFinishing;
602     Ref();
603     VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Finish";
604     // We call finish in response to completed (with error) response reading tag
605     // on some exchange. We let this exchange hang in ResponseReadIssued state.
606     // ExchangeQueue makes sure that there is at most one exchange in this
607     // state. So, no new reads will be issued.
608     call_->Finish(&call_status_, &finished_tag_);
609   }
610 
611   // Holds state for a single request/response exchange between the client
612   // and the server.
613   typedef typename UntypedStreamingRPCState::Tag Tag;
614 
615   // Order of context_ and call_ is important because context_ must outlive
616   // call_.
617   const std::shared_ptr<const ::grpc::ClientContext> context_;
618   std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_;
619 
620   mutable mutex mu_;
621   ExchangeQueue exchanges_ TF_GUARDED_BY(mu_);
622   State call_state_ TF_GUARDED_BY(mu_);
623   ::grpc::Status call_status_ TF_GUARDED_BY(mu_);
624 
625   // We can get away with having single instances of these tags per
626   // StreamingRPCState because we make sure (as gRPC requires) that
627   // there is at most one outstanding Read and at most one outstanding Write
628   // in the completion queue.
629   // Tags are immutable. No need to guard them.
630   Tag call_started_tag_{this, Tag::TagType::kCallStarted};
631   Tag request_write_completed_tag_{this, Tag::TagType::kRequestWriteCompleted};
632   Tag response_read_completed_tag_{this, Tag::TagType::kResponseReadCompleted};
633   Tag finished_tag_{this, Tag::TagType::kCallFinished};
634 };
635 
636 // Creates streaming calls and dispatches requests to them.
637 // In the common case, the client would create a StreamingRPCDispatcher for
638 // each bidirectional streaming RPC it might want to make. The first time, it
639 // calls SendNextRequest, a streaming call is initiated and the request is
640 // sent within this call. Initiation of the call blocks the client. If there are
641 // no errors, subsequent calls to SendNextRequest would use the already active
642 // call. If there was an error, the call object will be destroyed after all
643 // the callbacks for outstanding requests have been invoked. The next call to
644 // SendNextRequest will initiate a new call.
645 //
646 // Callbacks that are part of the same call, are invoked in the order they were
647 // provided, but callbacks across calls (a failed and a new one) can be invoked
648 // in any order.
649 //
650 // Thread-safe.
651 template <class Response>
652 class StreamingRPCDispatcher {
653  public:
StreamingRPCDispatcher(::grpc::GenericStub * stub,::grpc::CompletionQueue * cq,const::grpc::string & method)654   StreamingRPCDispatcher(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
655                          const ::grpc::string& method)
656       : stub_(stub), cq_(cq), method_(method) {}
657 
658   // Attempts to send the next request. If there is no active streaming call,
659   // starts one and sends the request on top of it. `done` is invoked when
660   // `response` has been filled with the data from the server, or if there
661   // is an error. `done` can be invoked before SendNextRequest returns.
SendNextRequest(const protobuf::Message & request,Response * response,StatusCallback done)662   void SendNextRequest(const protobuf::Message& request, Response* response,
663                        StatusCallback done) {
664     mutex_lock l(mu_);
665     if (state_ == nullptr) {
666       CreateStreamingState();
667     }
668 
669     bool is_call_alive = state_->SendNextRequest(request, response, done);
670     if (is_call_alive) {
671       return;
672     }
673 
674     // The attempt to send failed because the call was dead, create a new
675     // call and try again. When the call is dead SendNextRequest does not call
676     // `done`.
677     CreateStreamingState();
678 
679     is_call_alive = state_->SendNextRequest(request, response, done);
680     if (!is_call_alive) {
681       // Consider retrying to create and start a call few more times.
682       done(errors::Unknown("gRPC call failed right after it was created"));
683     }
684   }
685 
686   // Request to cancel the current streaming call. Non-blocking.
CancelCall()687   void CancelCall() {
688     mutex_lock l(mu_);
689     if (state_ == nullptr) {
690       return;
691     }
692     context_->TryCancel();
693     state_ = nullptr;
694   }
695 
696  private:
CreateStreamingState()697   void CreateStreamingState() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
698     // ClientContext cannot be reused across calls.
699     context_ = std::make_shared<::grpc::ClientContext>();
700     // Don't immediately fail StartCall if the channel is not ready. Wait for
701     // the channel to become ready.
702     context_->set_wait_for_ready(true);
703 
704     std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call =
705         stub_->PrepareCall(context_.get(), method_, cq_);
706 
707     state_.reset(new StreamingRPCState<Response>(std::move(call), context_));
708   }
709 
710   mutable mutex mu_;
711 
712   // Both are thread-safe
713   ::grpc::GenericStub* const stub_;
714   ::grpc::CompletionQueue* const cq_;
715 
716   // Does not need synchronization since it is constant.
717   const ::grpc::string method_;
718 
719   std::shared_ptr<::grpc::ClientContext> context_ TF_GUARDED_BY(mu_);
720   core::RefCountPtr<StreamingRPCState<Response>> state_ TF_GUARDED_BY(mu_);
721 };
722 
723 }  // namespace tensorflow
724 
725 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
726