1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
18 
19 #include "tensorflow/core/lib/core/refcount.h"
20 #include "tensorflow/core/platform/macros.h"
21 #include "tensorflow/core/platform/mutex.h"
22 
23 #include "grpcpp/grpcpp.h"
24 #include "grpcpp/impl/codegen/service_type.h"
25 #include "grpcpp/server_builder.h"
26 
27 namespace tensorflow {
28 
29 // CALL STRUCTURES
30 // ===============
31 //
32 // Each pending (incoming) request corresponds to a call object that
33 // encapsulates the state of the call. Templates and
34 // pointers-to-member functions are used to avoid boilerplate and
35 // redundant closure creation. The class hierarchy is as follows:
36 //
// * `UntypedCall<Service>`: The base class represents a call that
//   could be associated with any of the methods on a service of type
//   `Service`. It also defines a nested `Tag` class that can be used as
//   the tag in a `grpc::CompletionQueue`. Each class used as the
//   `Service` template argument should have a completion queue polling
//   loop that knows about `UntypedCall<Service>::Tag` objects and
//   invokes their `OnCompleted()` method to continue processing.
44 //
45 // * `Call<Service, GrpcService, Req, Resp>`: This class extends
46 //   `UntypedCall<Service>` and is additionally parameterized by the
47 //   gRPC-generated asynchronous service class, and the request and
48 //   response message types. It defines the state associated with a
49 //   call (whose type depends on the message types), and stores a
50 //   pointer to a `Service::HandleFoo()` handler method. Each
51 //   `Service::HandleFoo()` method knows about the corresponding
52 //   `Call` type, in order to access its state, and invoke its
53 //   `SendResponse()` method.
54 //
55 // The lifecycle of a call object is as follows.
56 //
57 // 1. A `Service` creates a `Call` for a particular method and
58 //    enqueues it in its completion queue (via an
59 //    `UntypedCall<Service>::Tag`).
60 //
// 2. When the tag is returned from `cq_->Next()`, the
//    `UntypedCall::RequestReceived()` method is invoked. It takes an
//    additional reference on the call object on behalf of the handler,
//    and indirectly invokes the appropriate handler method on `Service`.
//
// 3. After the handler has filled in the response (perhaps in another
//    thread), it invokes `Call::SendResponse()`. This hands a reference
//    on the call object to the completion queue (via an
//    `UntypedCall::Tag`) and releases the handler's reference.
//
// 4. When the response has been sent, the tag is returned from
//    `cq_->Next()` and its reference is released. The call object is
//    deleted once its last reference (including the one held by the
//    cancellation tag, if one was registered) has been released.
73 
74 // Represents a pending request with unknown message types.
75 template <class Service>
76 class UntypedCall : public core::RefCounted {
77  public:
~UntypedCall()78   virtual ~UntypedCall() {}
79 
80   // The implementation of this method should use `service` to handle
81   // an incoming request, and (perhaps asynchronously) send the
82   // response.
83   //
84   // One reference on `this` is transferred to the callee, and the
85   // callee is responsible for releasing it (typically via
86   // `Call::SendResponse()`).
87   //
88   // `ok` is true if the request was received in a "regular event",
89   // otherwise false.
90   virtual void RequestReceived(Service* service, bool ok) = 0;
91 
92   // This method will be called either (i) when the server is notified
93   // that the request has been canceled, or (ii) when the request completes
94   // normally. The implementation should distinguish these cases by querying
95   // the `grpc::ServerContext` associated with the request.
96   virtual void RequestCancelled(Service* service, bool ok) = 0;
97 
98   // Associates a tag in a `::grpc::CompletionQueue` with a callback
99   // for an incoming RPC.  An active Tag owns a reference on the corresponding
100   // Call object.
101   class Tag {
102    public:
103     // One enum value per supported callback.
104     enum Callback { kRequestReceived, kResponseSent, kCancelled };
105 
Tag(UntypedCall * call,Callback cb)106     Tag(UntypedCall* call, Callback cb) : call_(call), callback_(cb) {}
107 
108     // Calls the callback associated with this tag.
109     //
110     // The callback takes ownership of `this->call_`.
OnCompleted(Service * service,bool ok)111     void OnCompleted(Service* service, bool ok) {
112       switch (callback_) {
113         case kRequestReceived:
114           call_->RequestReceived(service, ok);
115           break;
116         case kResponseSent:
117           // No special handling needed apart from the Unref below.
118           break;
119         case kCancelled:
120           call_->RequestCancelled(service, ok);
121           break;
122       }
123       call_->Unref();  // Ref acquired when tag handed to grpc.
124     }
125 
126    private:
127     UntypedCall* const call_;  // `this` owns one reference.
128     Callback callback_;
129   };
130 };
131 
132 // Represents a pending call with known request and response message
133 // types, and a known request-handling method.
134 template <class Service, class GrpcService, class RequestMessage,
135           class ResponseMessage>
136 class Call : public UntypedCall<Service> {
137  public:
138   // Represents the generic signature of a generated
139   // `GrpcService::RequestFoo()` method, where `Foo` is the name of an
140   // RPC method.
141   using EnqueueFunction = void (GrpcService::*)(
142       ::grpc::ServerContext*, RequestMessage*,
143       ::grpc::ServerAsyncResponseWriter<ResponseMessage>*,
144       ::grpc::CompletionQueue*, ::grpc::ServerCompletionQueue*, void*);
145 
146   // Represents the generic signature of a `Service::HandleFoo()`
147   // method, where `Foo` is the name of an RPC method.
148   using HandleRequestFunction = void (Service::*)(
149       Call<Service, GrpcService, RequestMessage, ResponseMessage>*);
150 
Call(HandleRequestFunction handle_request_function)151   Call(HandleRequestFunction handle_request_function)
152       : handle_request_function_(handle_request_function), responder_(&ctx_) {}
153 
~Call()154   virtual ~Call() {}
155 
RequestReceived(Service * service,bool ok)156   void RequestReceived(Service* service, bool ok) override {
157     if (ok) {
158       this->Ref();
159       (service->*handle_request_function_)(this);
160     }
161   }
162 
SendResponse(::grpc::Status status)163   void SendResponse(::grpc::Status status) {
164     this->Ref();  // Ref for grpc; released in Tag callback.
165     responder_.Finish(response, status, &response_sent_tag_);
166     this->Unref();
167   }
168 
RequestCancelled(Service * service,bool ok)169   void RequestCancelled(Service* service, bool ok) override {
170     if (ctx_.IsCancelled()) {
171       mutex_lock l(mu_);
172       if (cancel_callback_) {
173         cancel_callback_();
174       }
175     }
176   }
177 
178   // Registers `callback` as the function that should be called if and when this
179   // call is canceled by the client.
SetCancelCallback(std::function<void ()> callback)180   void SetCancelCallback(std::function<void()> callback) {
181     mutex_lock l(mu_);
182     cancel_callback_ = std::move(callback);
183   }
184 
185   // Clears any cancellation callback that has been registered for this call.
ClearCancelCallback()186   void ClearCancelCallback() {
187     mutex_lock l(mu_);
188     cancel_callback_ = nullptr;
189   }
190 
191   // Enqueues a new request for the given service on the given
192   // completion queue, using the given `enqueue_function`.
193   //
194   // The request will be handled with the given
195   // `handle_request_function`.
EnqueueRequest(GrpcService * grpc_service,::grpc::ServerCompletionQueue * cq,EnqueueFunction enqueue_function,HandleRequestFunction handle_request_function,bool supports_cancel)196   static void EnqueueRequest(GrpcService* grpc_service,
197                              ::grpc::ServerCompletionQueue* cq,
198                              EnqueueFunction enqueue_function,
199                              HandleRequestFunction handle_request_function,
200                              bool supports_cancel) {
201     auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>(
202         handle_request_function);
203     if (supports_cancel) {
204       call->RegisterCancellationHandler();
205     }
206 
207     // Initial ref for call handed to grpc; released in Tag callback.
208     (grpc_service->*enqueue_function)(&call->ctx_, &call->request,
209                                       &call->responder_, cq, cq,
210                                       &call->request_received_tag_);
211   }
212 
213   // Enqueues a new request for the given service on the given
214   // completion queue, using the given `method_id`.
215   //
216   // The request will be handled with the given
217   // `handle_request_function`.
EnqueueRequestForMethod(GrpcService * grpc_service,::grpc::ServerCompletionQueue * cq,int method_id,HandleRequestFunction handle_request_function,bool supports_cancel)218   static void EnqueueRequestForMethod(
219       GrpcService* grpc_service, ::grpc::ServerCompletionQueue* cq,
220       int method_id, HandleRequestFunction handle_request_function,
221       bool supports_cancel) {
222     auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>(
223         handle_request_function);
224     if (supports_cancel) {
225       call->RegisterCancellationHandler();
226     }
227 
228     // Initial ref for call handed to grpc; released in Tag callback.
229     grpc_service->RequestAsyncUnary(method_id, &call->ctx_, &call->request,
230                                     &call->responder_, cq, cq,
231                                     &call->request_received_tag_);
232   }
233 
234   RequestMessage request;
235   ResponseMessage response;
236 
client_metadata()237   const std::multimap<::grpc::string_ref, ::grpc::string_ref>& client_metadata()
238       const {
239     return ctx_.client_metadata();
240   }
241 
242  private:
243   // Creates a completion queue tag for handling cancellation by the client.
244   // NOTE: This method must be called before this call is enqueued on a
245   // completion queue.
RegisterCancellationHandler()246   void RegisterCancellationHandler() {
247     this->Ref();  // Ref for grpc; released in Tag callback.
248     ctx_.AsyncNotifyWhenDone(&cancelled_tag_);
249   }
250 
251   HandleRequestFunction handle_request_function_;
252   ::grpc::ServerContext ctx_;
253   ::grpc::ServerAsyncResponseWriter<ResponseMessage> responder_;
254 
255   // Used as void* completion markers from grpc to indicate different
256   // events of interest for a Call.
257   typedef typename UntypedCall<Service>::Tag Tag;
258   Tag request_received_tag_{this, Tag::kRequestReceived};
259   Tag response_sent_tag_{this, Tag::kResponseSent};
260   Tag cancelled_tag_{this, Tag::kCancelled};
261 
262   mutex mu_;
263   std::function<void()> cancel_callback_ GUARDED_BY(mu_);
264 };
265 
266 }  // namespace tensorflow
267 
268 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
269