1 /*
2  *
3  * Copyright 2015 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
19 #ifndef GRPCPP_IMPL_CODEGEN_METHOD_HANDLER_IMPL_H
20 #define GRPCPP_IMPL_CODEGEN_METHOD_HANDLER_IMPL_H
21 
#include <functional>
#include <utility>

#include <grpcpp/impl/codegen/byte_buffer.h>
#include <grpcpp/impl/codegen/core_codegen_interface.h>
#include <grpcpp/impl/codegen/rpc_service_method.h>
#include <grpcpp/impl/codegen/sync_stream.h>
26 
27 namespace grpc {
28 
29 namespace internal {
30 
31 // Invoke the method handler, fill in the status, and
32 // return whether or not we finished safely (without an exception).
33 // Note that exception handling is 0-cost in most compiler/library
34 // implementations (except when an exception is actually thrown),
35 // so this process doesn't require additional overhead in the common case.
36 // Additionally, we don't need to return if we caught an exception or not;
37 // the handling is the same in either case.
38 template <class Callable>
CatchingFunctionHandler(Callable && handler)39 Status CatchingFunctionHandler(Callable&& handler) {
40 #if GRPC_ALLOW_EXCEPTIONS
41   try {
42     return handler();
43   } catch (...) {
44     return Status(StatusCode::UNKNOWN, "Unexpected error in RPC handling");
45   }
46 #else   // GRPC_ALLOW_EXCEPTIONS
47   return handler();
48 #endif  // GRPC_ALLOW_EXCEPTIONS
49 }
50 
51 /// A wrapper class of an application provided rpc method handler.
52 template <class ServiceType, class RequestType, class ResponseType>
53 class RpcMethodHandler : public MethodHandler {
54  public:
RpcMethodHandler(std::function<Status (ServiceType *,ServerContext *,const RequestType *,ResponseType *)> func,ServiceType * service)55   RpcMethodHandler(std::function<Status(ServiceType*, ServerContext*,
56                                         const RequestType*, ResponseType*)>
57                        func,
58                    ServiceType* service)
59       : func_(func), service_(service) {}
60 
RunHandler(const HandlerParameter & param)61   void RunHandler(const HandlerParameter& param) final {
62     RequestType req;
63     Status status = SerializationTraits<RequestType>::Deserialize(
64         param.request.bbuf_ptr(), &req);
65     ResponseType rsp;
66     if (status.ok()) {
67       status = CatchingFunctionHandler([this, &param, &req, &rsp] {
68         return func_(service_, param.server_context, &req, &rsp);
69       });
70     }
71 
72     GPR_CODEGEN_ASSERT(!param.server_context->sent_initial_metadata_);
73     CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage,
74               CallOpServerSendStatus>
75         ops;
76     ops.SendInitialMetadata(param.server_context->initial_metadata_,
77                             param.server_context->initial_metadata_flags());
78     if (param.server_context->compression_level_set()) {
79       ops.set_compression_level(param.server_context->compression_level());
80     }
81     if (status.ok()) {
82       status = ops.SendMessage(rsp);
83     }
84     ops.ServerSendStatus(param.server_context->trailing_metadata_, status);
85     param.call->PerformOps(&ops);
86     param.call->cq()->Pluck(&ops);
87   }
88 
89  private:
90   /// Application provided rpc handler function.
91   std::function<Status(ServiceType*, ServerContext*, const RequestType*,
92                        ResponseType*)>
93       func_;
94   // The class the above handler function lives in.
95   ServiceType* service_;
96 };
97 
98 /// A wrapper class of an application provided client streaming handler.
99 template <class ServiceType, class RequestType, class ResponseType>
100 class ClientStreamingHandler : public MethodHandler {
101  public:
ClientStreamingHandler(std::function<Status (ServiceType *,ServerContext *,ServerReader<RequestType> *,ResponseType *)> func,ServiceType * service)102   ClientStreamingHandler(
103       std::function<Status(ServiceType*, ServerContext*,
104                            ServerReader<RequestType>*, ResponseType*)>
105           func,
106       ServiceType* service)
107       : func_(func), service_(service) {}
108 
RunHandler(const HandlerParameter & param)109   void RunHandler(const HandlerParameter& param) final {
110     ServerReader<RequestType> reader(param.call, param.server_context);
111     ResponseType rsp;
112     Status status = CatchingFunctionHandler([this, &param, &reader, &rsp] {
113       return func_(service_, param.server_context, &reader, &rsp);
114     });
115 
116     CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage,
117               CallOpServerSendStatus>
118         ops;
119     if (!param.server_context->sent_initial_metadata_) {
120       ops.SendInitialMetadata(param.server_context->initial_metadata_,
121                               param.server_context->initial_metadata_flags());
122       if (param.server_context->compression_level_set()) {
123         ops.set_compression_level(param.server_context->compression_level());
124       }
125     }
126     if (status.ok()) {
127       status = ops.SendMessage(rsp);
128     }
129     ops.ServerSendStatus(param.server_context->trailing_metadata_, status);
130     param.call->PerformOps(&ops);
131     param.call->cq()->Pluck(&ops);
132   }
133 
134  private:
135   std::function<Status(ServiceType*, ServerContext*, ServerReader<RequestType>*,
136                        ResponseType*)>
137       func_;
138   ServiceType* service_;
139 };
140 
141 /// A wrapper class of an application provided server streaming handler.
142 template <class ServiceType, class RequestType, class ResponseType>
143 class ServerStreamingHandler : public MethodHandler {
144  public:
ServerStreamingHandler(std::function<Status (ServiceType *,ServerContext *,const RequestType *,ServerWriter<ResponseType> *)> func,ServiceType * service)145   ServerStreamingHandler(
146       std::function<Status(ServiceType*, ServerContext*, const RequestType*,
147                            ServerWriter<ResponseType>*)>
148           func,
149       ServiceType* service)
150       : func_(func), service_(service) {}
151 
RunHandler(const HandlerParameter & param)152   void RunHandler(const HandlerParameter& param) final {
153     RequestType req;
154     Status status = SerializationTraits<RequestType>::Deserialize(
155         param.request.bbuf_ptr(), &req);
156 
157     if (status.ok()) {
158       ServerWriter<ResponseType> writer(param.call, param.server_context);
159       status = CatchingFunctionHandler([this, &param, &req, &writer] {
160         return func_(service_, param.server_context, &req, &writer);
161       });
162     }
163 
164     CallOpSet<CallOpSendInitialMetadata, CallOpServerSendStatus> ops;
165     if (!param.server_context->sent_initial_metadata_) {
166       ops.SendInitialMetadata(param.server_context->initial_metadata_,
167                               param.server_context->initial_metadata_flags());
168       if (param.server_context->compression_level_set()) {
169         ops.set_compression_level(param.server_context->compression_level());
170       }
171     }
172     ops.ServerSendStatus(param.server_context->trailing_metadata_, status);
173     param.call->PerformOps(&ops);
174     if (param.server_context->has_pending_ops_) {
175       param.call->cq()->Pluck(&param.server_context->pending_ops_);
176     }
177     param.call->cq()->Pluck(&ops);
178   }
179 
180  private:
181   std::function<Status(ServiceType*, ServerContext*, const RequestType*,
182                        ServerWriter<ResponseType>*)>
183       func_;
184   ServiceType* service_;
185 };
186 
187 /// A wrapper class of an application provided bidi-streaming handler.
188 /// This also applies to server-streamed implementation of a unary method
189 /// with the additional requirement that such methods must have done a
190 /// write for status to be ok
191 /// Since this is used by more than 1 class, the service is not passed in.
192 /// Instead, it is expected to be an implicitly-captured argument of func
193 /// (through bind or something along those lines)
194 template <class Streamer, bool WriteNeeded>
195 class TemplatedBidiStreamingHandler : public MethodHandler {
196  public:
TemplatedBidiStreamingHandler(std::function<Status (ServerContext *,Streamer *)> func)197   TemplatedBidiStreamingHandler(
198       std::function<Status(ServerContext*, Streamer*)> func)
199       : func_(func), write_needed_(WriteNeeded) {}
200 
RunHandler(const HandlerParameter & param)201   void RunHandler(const HandlerParameter& param) final {
202     Streamer stream(param.call, param.server_context);
203     Status status = CatchingFunctionHandler([this, &param, &stream] {
204       return func_(param.server_context, &stream);
205     });
206 
207     CallOpSet<CallOpSendInitialMetadata, CallOpServerSendStatus> ops;
208     if (!param.server_context->sent_initial_metadata_) {
209       ops.SendInitialMetadata(param.server_context->initial_metadata_,
210                               param.server_context->initial_metadata_flags());
211       if (param.server_context->compression_level_set()) {
212         ops.set_compression_level(param.server_context->compression_level());
213       }
214       if (write_needed_ && status.ok()) {
215         // If we needed a write but never did one, we need to mark the
216         // status as a fail
217         status = Status(StatusCode::INTERNAL,
218                         "Service did not provide response message");
219       }
220     }
221     ops.ServerSendStatus(param.server_context->trailing_metadata_, status);
222     param.call->PerformOps(&ops);
223     if (param.server_context->has_pending_ops_) {
224       param.call->cq()->Pluck(&param.server_context->pending_ops_);
225     }
226     param.call->cq()->Pluck(&ops);
227   }
228 
229  private:
230   std::function<Status(ServerContext*, Streamer*)> func_;
231   const bool write_needed_;
232 };
233 
234 template <class ServiceType, class RequestType, class ResponseType>
235 class BidiStreamingHandler
236     : public TemplatedBidiStreamingHandler<
237           ServerReaderWriter<ResponseType, RequestType>, false> {
238  public:
BidiStreamingHandler(std::function<Status (ServiceType *,ServerContext *,ServerReaderWriter<ResponseType,RequestType> *)> func,ServiceType * service)239   BidiStreamingHandler(
240       std::function<Status(ServiceType*, ServerContext*,
241                            ServerReaderWriter<ResponseType, RequestType>*)>
242           func,
243       ServiceType* service)
244       : TemplatedBidiStreamingHandler<
245             ServerReaderWriter<ResponseType, RequestType>, false>(std::bind(
246             func, service, std::placeholders::_1, std::placeholders::_2)) {}
247 };
248 
249 template <class RequestType, class ResponseType>
250 class StreamedUnaryHandler
251     : public TemplatedBidiStreamingHandler<
252           ServerUnaryStreamer<RequestType, ResponseType>, true> {
253  public:
StreamedUnaryHandler(std::function<Status (ServerContext *,ServerUnaryStreamer<RequestType,ResponseType> *)> func)254   explicit StreamedUnaryHandler(
255       std::function<Status(ServerContext*,
256                            ServerUnaryStreamer<RequestType, ResponseType>*)>
257           func)
258       : TemplatedBidiStreamingHandler<
259             ServerUnaryStreamer<RequestType, ResponseType>, true>(func) {}
260 };
261 
262 template <class RequestType, class ResponseType>
263 class SplitServerStreamingHandler
264     : public TemplatedBidiStreamingHandler<
265           ServerSplitStreamer<RequestType, ResponseType>, false> {
266  public:
SplitServerStreamingHandler(std::function<Status (ServerContext *,ServerSplitStreamer<RequestType,ResponseType> *)> func)267   explicit SplitServerStreamingHandler(
268       std::function<Status(ServerContext*,
269                            ServerSplitStreamer<RequestType, ResponseType>*)>
270           func)
271       : TemplatedBidiStreamingHandler<
272             ServerSplitStreamer<RequestType, ResponseType>, false>(func) {}
273 };
274 
275 /// General method handler class for errors that prevent real method use
276 /// e.g., handle unknown method by returning UNIMPLEMENTED error.
277 template <StatusCode code>
278 class ErrorMethodHandler : public MethodHandler {
279  public:
280   template <class T>
FillOps(ServerContext * context,T * ops)281   static void FillOps(ServerContext* context, T* ops) {
282     Status status(code, "");
283     if (!context->sent_initial_metadata_) {
284       ops->SendInitialMetadata(context->initial_metadata_,
285                                context->initial_metadata_flags());
286       if (context->compression_level_set()) {
287         ops->set_compression_level(context->compression_level());
288       }
289       context->sent_initial_metadata_ = true;
290     }
291     ops->ServerSendStatus(context->trailing_metadata_, status);
292   }
293 
RunHandler(const HandlerParameter & param)294   void RunHandler(const HandlerParameter& param) final {
295     CallOpSet<CallOpSendInitialMetadata, CallOpServerSendStatus> ops;
296     FillOps(param.server_context, &ops);
297     param.call->PerformOps(&ops);
298     param.call->cq()->Pluck(&ops);
299     // We also have to destroy any request payload in the handler parameter
300     ByteBuffer* payload = param.request.bbuf_ptr();
301     if (payload != nullptr) {
302       payload->Clear();
303     }
304   }
305 };
306 
307 typedef ErrorMethodHandler<StatusCode::UNIMPLEMENTED> UnknownMethodHandler;
308 typedef ErrorMethodHandler<StatusCode::RESOURCE_EXHAUSTED>
309     ResourceExhaustedHandler;
310 
311 }  // namespace internal
312 }  // namespace grpc
313 
314 #endif  // GRPCPP_IMPL_CODEGEN_METHOD_HANDLER_IMPL_H
315