1 //===------- RPCUTils.h - Utilities for building RPC APIs -------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Utilities to support construction of simple RPC APIs.
11 //
12 // The RPC utilities aim for ease of use (minimal conceptual overhead) for C++
13 // programmers, high performance, low memory overhead, and efficient use of the
14 // communications channel.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
19 #define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
20 
21 #include <map>
22 #include <thread>
23 #include <vector>
24 
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ExecutionEngine/Orc/OrcError.h"
27 #include "llvm/ExecutionEngine/Orc/RPCSerialization.h"
28 
29 #include <future>
30 
31 namespace llvm {
32 namespace orc {
33 namespace rpc {
34 
/// Base class of all fatal RPC errors (those that necessarily result in the
/// termination of the RPC session).
///
/// Concrete fatal errors derive from it via ErrorInfo<Derived, RPCFatalError>
/// (see BadFunctionCall, InvalidSequenceNumberForResponse). RespondHelper tests
/// for this category with errorIsA<RPCFatalError>() / isA<RPCFatalError>() and
/// hands matching errors back to the local caller instead of sending them to
/// the remote.
///
/// NOTE(review): no log()/convertToErrorCode() overrides are declared here;
/// presumably this class is only used as a category base and is never
/// instantiated directly - confirm.
class RPCFatalError : public ErrorInfo<RPCFatalError> {
public:
  static char ID;
};
41 
/// ConnectionClosed is returned from RPC operations if the RPC connection
/// has already been closed due to either an error or graceful disconnection.
///
/// Only the interface is declared here; convertToErrorCode() and log() are
/// defined out of line.
class ConnectionClosed : public ErrorInfo<ConnectionClosed> {
public:
  static char ID;
  std::error_code convertToErrorCode() const override;
  void log(raw_ostream &OS) const override;
};
50 
51 /// BadFunctionCall is returned from handleOne when the remote makes a call with
52 /// an unrecognized function id.
53 ///
54 /// This error is fatal because Orc RPC needs to know how to parse a function
55 /// call to know where the next call starts, and if it doesn't recognize the
56 /// function id it cannot parse the call.
57 template <typename FnIdT, typename SeqNoT>
58 class BadFunctionCall
59   : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> {
60 public:
61   static char ID;
62 
BadFunctionCall(FnIdT FnId,SeqNoT SeqNo)63   BadFunctionCall(FnIdT FnId, SeqNoT SeqNo)
64       : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {}
65 
convertToErrorCode()66   std::error_code convertToErrorCode() const override {
67     return orcError(OrcErrorCode::UnexpectedRPCCall);
68   }
69 
log(raw_ostream & OS)70   void log(raw_ostream &OS) const override {
71     OS << "Call to invalid RPC function id '" << FnId << "' with "
72           "sequence number " << SeqNo;
73   }
74 
75 private:
76   FnIdT FnId;
77   SeqNoT SeqNo;
78 };
79 
80 template <typename FnIdT, typename SeqNoT>
81 char BadFunctionCall<FnIdT, SeqNoT>::ID = 0;
82 
83 /// InvalidSequenceNumberForResponse is returned from handleOne when a response
84 /// call arrives with a sequence number that doesn't correspond to any in-flight
85 /// function call.
86 ///
87 /// This error is fatal because Orc RPC needs to know how to parse the rest of
88 /// the response call to know where the next call starts, and if it doesn't have
89 /// a result parser for this sequence number it can't do that.
90 template <typename SeqNoT>
91 class InvalidSequenceNumberForResponse
92     : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> {
93 public:
94   static char ID;
95 
InvalidSequenceNumberForResponse(SeqNoT SeqNo)96   InvalidSequenceNumberForResponse(SeqNoT SeqNo)
97       : SeqNo(std::move(SeqNo)) {}
98 
convertToErrorCode()99   std::error_code convertToErrorCode() const override {
100     return orcError(OrcErrorCode::UnexpectedRPCCall);
101   };
102 
log(raw_ostream & OS)103   void log(raw_ostream &OS) const override {
104     OS << "Response has unknown sequence number " << SeqNo;
105   }
106 private:
107   SeqNoT SeqNo;
108 };
109 
110 template <typename SeqNoT>
111 char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0;
112 
/// This non-fatal error will be passed to asynchronous result handlers in place
/// of a result if the connection goes down before a result returns, or if the
/// function to be called cannot be negotiated with the remote.
///
/// In this header it is created by
/// ResponseHandler::createAbandonedResponseError() and delivered to user
/// handlers by the ResponseHandlerImpl::abandon() overrides.
/// convertToErrorCode() and log() are defined out of line.
class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> {
public:
  static char ID;

  std::error_code convertToErrorCode() const override;
  void log(raw_ostream &OS) const override;
};
123 
/// This error is returned if the remote does not have a handler installed for
/// the given RPC function.
class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> {
public:
  static char ID;

  /// \p Signature identifies the function that could not be negotiated.
  /// NOTE(review): presumably the string from Function::getPrototype() -
  /// confirm at the call sites; the constructor is defined out of line.
  CouldNotNegotiate(std::string Signature);
  std::error_code convertToErrorCode() const override;
  void log(raw_ostream &OS) const override;
  /// Returns the stored signature string.
  const std::string &getSignature() const { return Signature; }
private:
  std::string Signature;
};
137 
138 template <typename DerivedFunc, typename FnT> class Function;
139 
140 // RPC Function class.
141 // DerivedFunc should be a user defined class with a static 'getName()' method
142 // returning a const char* representing the function's name.
143 template <typename DerivedFunc, typename RetT, typename... ArgTs>
144 class Function<DerivedFunc, RetT(ArgTs...)> {
145 public:
146   /// User defined function type.
147   using Type = RetT(ArgTs...);
148 
149   /// Return type.
150   using ReturnType = RetT;
151 
152   /// Returns the full function prototype as a string.
getPrototype()153   static const char *getPrototype() {
154     std::lock_guard<std::mutex> Lock(NameMutex);
155     if (Name.empty())
156       raw_string_ostream(Name)
157           << RPCTypeName<RetT>::getName() << " " << DerivedFunc::getName()
158           << "(" << llvm::orc::rpc::RPCTypeNameSequence<ArgTs...>() << ")";
159     return Name.data();
160   }
161 
162 private:
163   static std::mutex NameMutex;
164   static std::string Name;
165 };
166 
167 template <typename DerivedFunc, typename RetT, typename... ArgTs>
168 std::mutex Function<DerivedFunc, RetT(ArgTs...)>::NameMutex;
169 
170 template <typename DerivedFunc, typename RetT, typename... ArgTs>
171 std::string Function<DerivedFunc, RetT(ArgTs...)>::Name;
172 
/// Allocates RPC function ids during autonegotiation.
/// Specializations of this class must provide four members:
///
/// static T getInvalidId():
///   Should return a reserved id that will be used to represent missing
/// functions during autonegotiation.
///
/// static T getResponseId():
///   Should return a reserved id that will be used to send function responses
/// (return values).
///
/// static T getNegotiateId():
///   Should return a reserved id for the negotiate function, which will be used
/// to negotiate ids for user defined functions.
///
/// template <typename Func> T allocate():
///   Allocate a unique id for function Func.
template <typename T, typename = void> class RPCFunctionIdAllocator;

/// Default RPCFunctionIdAllocator for integral id types. Ids 0, 1 and 2 are
/// reserved (invalid, response, negotiate). allocate() hands out 3, 4, 5, ...
/// in call order; its Func parameter is not used by this implementation.
template <typename T>
class RPCFunctionIdAllocator<
    T, typename std::enable_if<std::is_integral<T>::value>::type> {
public:
  /// Reserved id marking a function the remote does not provide.
  static T getInvalidId() { return static_cast<T>(0); }

  /// Reserved id used for response (return value) messages.
  static T getResponseId() { return static_cast<T>(1); }

  /// Reserved id used by the negotiate function itself.
  static T getNegotiateId() { return static_cast<T>(2); }

  /// Returns the next unused id, then advances the counter.
  template <typename Func> T allocate() {
    const T Allocated = NextId;
    ++NextId;
    return Allocated;
  }

private:
  /// First id not covered by the reserved range.
  T NextId = 3;
};
207 
208 namespace detail {
209 
210 // FIXME: Remove MSVCPError/MSVCPExpected once MSVC's future implementation
211 //        supports classes without default constructors.
212 #ifdef _MSC_VER
213 
214 namespace msvc_hacks {
215 
216 // Work around MSVC's future implementation's use of default constructors:
217 // A default constructed value in the promise will be overwritten when the
218 // real error is set - so the default constructed Error has to be checked
219 // already.
220 class MSVCPError : public Error {
221 public:
MSVCPError()222   MSVCPError() { (void)!!*this; }
223 
MSVCPError(MSVCPError && Other)224   MSVCPError(MSVCPError &&Other) : Error(std::move(Other)) {}
225 
226   MSVCPError &operator=(MSVCPError Other) {
227     Error::operator=(std::move(Other));
228     return *this;
229   }
230 
MSVCPError(Error Err)231   MSVCPError(Error Err) : Error(std::move(Err)) {}
232 };
233 
234 // Work around MSVC's future implementation, similar to MSVCPError.
235 template <typename T> class MSVCPExpected : public Expected<T> {
236 public:
MSVCPExpected()237   MSVCPExpected()
238       : Expected<T>(make_error<StringError>("", inconvertibleErrorCode())) {
239     consumeError(this->takeError());
240   }
241 
MSVCPExpected(MSVCPExpected && Other)242   MSVCPExpected(MSVCPExpected &&Other) : Expected<T>(std::move(Other)) {}
243 
244   MSVCPExpected &operator=(MSVCPExpected &&Other) {
245     Expected<T>::operator=(std::move(Other));
246     return *this;
247   }
248 
MSVCPExpected(Error Err)249   MSVCPExpected(Error Err) : Expected<T>(std::move(Err)) {}
250 
251   template <typename OtherT>
252   MSVCPExpected(
253       OtherT &&Val,
254       typename std::enable_if<std::is_convertible<OtherT, T>::value>::type * =
255           nullptr)
256       : Expected<T>(std::move(Val)) {}
257 
258   template <class OtherT>
259   MSVCPExpected(
260       Expected<OtherT> &&Other,
261       typename std::enable_if<std::is_convertible<OtherT, T>::value>::type * =
262           nullptr)
263       : Expected<T>(std::move(Other)) {}
264 
265   template <class OtherT>
266   explicit MSVCPExpected(
267       Expected<OtherT> &&Other,
268       typename std::enable_if<!std::is_convertible<OtherT, T>::value>::type * =
269           nullptr)
270       : Expected<T>(std::move(Other)) {}
271 };
272 
273 } // end namespace msvc_hacks
274 
275 #endif // _MSC_VER
276 
/// Provides a typedef for a tuple containing the decayed argument types of a
/// function type. For example, FunctionArgsTuple<void(const int &,
/// std::string &&)>::Type is std::tuple<int, std::string>.
template <typename T> class FunctionArgsTuple;

template <typename RetT, typename... ArgTs>
class FunctionArgsTuple<RetT(ArgTs...)> {
public:
  // std::decay removes references (and top-level cv-qualifiers) itself, so a
  // separate std::remove_reference step is unnecessary.
  using Type = std::tuple<typename std::decay<ArgTs>::type...>;
};
286 
287 // ResultTraits provides typedefs and utilities specific to the return type
288 // of functions.
289 template <typename RetT> class ResultTraits {
290 public:
291   // The return type wrapped in llvm::Expected.
292   using ErrorReturnType = Expected<RetT>;
293 
294 #ifdef _MSC_VER
295   // The ErrorReturnType wrapped in a std::promise.
296   using ReturnPromiseType = std::promise<msvc_hacks::MSVCPExpected<RetT>>;
297 
298   // The ErrorReturnType wrapped in a std::future.
299   using ReturnFutureType = std::future<msvc_hacks::MSVCPExpected<RetT>>;
300 #else
301   // The ErrorReturnType wrapped in a std::promise.
302   using ReturnPromiseType = std::promise<ErrorReturnType>;
303 
304   // The ErrorReturnType wrapped in a std::future.
305   using ReturnFutureType = std::future<ErrorReturnType>;
306 #endif
307 
308   // Create a 'blank' value of the ErrorReturnType, ready and safe to
309   // overwrite.
createBlankErrorReturnValue()310   static ErrorReturnType createBlankErrorReturnValue() {
311     return ErrorReturnType(RetT());
312   }
313 
314   // Consume an abandoned ErrorReturnType.
consumeAbandoned(ErrorReturnType RetOrErr)315   static void consumeAbandoned(ErrorReturnType RetOrErr) {
316     consumeError(RetOrErr.takeError());
317   }
318 };
319 
320 // ResultTraits specialization for void functions.
321 template <> class ResultTraits<void> {
322 public:
323   // For void functions, ErrorReturnType is llvm::Error.
324   using ErrorReturnType = Error;
325 
326 #ifdef _MSC_VER
327   // The ErrorReturnType wrapped in a std::promise.
328   using ReturnPromiseType = std::promise<msvc_hacks::MSVCPError>;
329 
330   // The ErrorReturnType wrapped in a std::future.
331   using ReturnFutureType = std::future<msvc_hacks::MSVCPError>;
332 #else
333   // The ErrorReturnType wrapped in a std::promise.
334   using ReturnPromiseType = std::promise<ErrorReturnType>;
335 
336   // The ErrorReturnType wrapped in a std::future.
337   using ReturnFutureType = std::future<ErrorReturnType>;
338 #endif
339 
340   // Create a 'blank' value of the ErrorReturnType, ready and safe to
341   // overwrite.
createBlankErrorReturnValue()342   static ErrorReturnType createBlankErrorReturnValue() {
343     return ErrorReturnType::success();
344   }
345 
346   // Consume an abandoned ErrorReturnType.
consumeAbandoned(ErrorReturnType Err)347   static void consumeAbandoned(ErrorReturnType Err) {
348     consumeError(std::move(Err));
349   }
350 };
351 
// ResultTraits<Error> is equivalent to ResultTraits<void>. This allows
// handlers for void RPC functions to return either void (in which case they
// implicitly succeed) or Error (in which case their error return is
// propagated).
// NOTE(review): this comment previously pointed at
// HandlerTraits::runHandlerHelper, but HandlerTraits in this header has no such
// member (the closest is HandlerTraits::run) - confirm the intended user.
template <> class ResultTraits<Error> : public ResultTraits<void> {};

// ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows
// handlers for RPC functions returning a T to return either a T (in which
// case they implicitly succeed) or Expected<T> (in which case their error
// return is propagated). (Same caveat as above about the former reference to
// HandlerTraits::runHandlerHelper.)
template <typename RetT>
class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {};
364 
// Determines whether an RPC function's declared return type supports error
// return values: value is true for Error and Expected<T>, false for every
// other type. respond() uses it to select the RespondHelper specialization.
template <typename T>
class SupportsErrorReturn {
public:
  static const bool value = false;
};

// Error-returning functions can send an error back to the caller.
template <>
class SupportsErrorReturn<Error> {
public:
  static const bool value = true;
};

// Expected<T>-returning functions can send either a value or an error.
template <typename T>
class SupportsErrorReturn<Expected<T>> {
public:
  static const bool value = true;
};
384 
385 // RespondHelper packages return values based on whether or not the declared
386 // RPC function return type supports error returns.
387 template <bool FuncSupportsErrorReturn>
388 class RespondHelper;
389 
390 // RespondHelper specialization for functions that support error returns.
391 template <>
392 class RespondHelper<true> {
393 public:
394 
395   // Send Expected<T>.
396   template <typename WireRetT, typename HandlerRetT, typename ChannelT,
397             typename FunctionIdT, typename SequenceNumberT>
sendResult(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Expected<HandlerRetT> ResultOrErr)398   static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
399                           SequenceNumberT SeqNo,
400                           Expected<HandlerRetT> ResultOrErr) {
401     if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>())
402       return ResultOrErr.takeError();
403 
404     // Open the response message.
405     if (auto Err = C.startSendMessage(ResponseId, SeqNo))
406       return Err;
407 
408     // Serialize the result.
409     if (auto Err =
410         SerializationTraits<ChannelT, WireRetT,
411                             Expected<HandlerRetT>>::serialize(
412                                                      C, std::move(ResultOrErr)))
413       return Err;
414 
415     // Close the response message.
416     return C.endSendMessage();
417   }
418 
419   template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
sendResult(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Error Err)420   static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
421                           SequenceNumberT SeqNo, Error Err) {
422     if (Err && Err.isA<RPCFatalError>())
423       return Err;
424     if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
425       return Err2;
426     if (auto Err2 = serializeSeq(C, std::move(Err)))
427       return Err2;
428     return C.endSendMessage();
429   }
430 
431 };
432 
433 // RespondHelper specialization for functions that do not support error returns.
434 template <>
435 class RespondHelper<false> {
436 public:
437 
438   template <typename WireRetT, typename HandlerRetT, typename ChannelT,
439             typename FunctionIdT, typename SequenceNumberT>
sendResult(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Expected<HandlerRetT> ResultOrErr)440   static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
441                           SequenceNumberT SeqNo,
442                           Expected<HandlerRetT> ResultOrErr) {
443     if (auto Err = ResultOrErr.takeError())
444       return Err;
445 
446     // Open the response message.
447     if (auto Err = C.startSendMessage(ResponseId, SeqNo))
448       return Err;
449 
450     // Serialize the result.
451     if (auto Err =
452         SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
453                                                                C, *ResultOrErr))
454       return Err;
455 
456     // Close the response message.
457     return C.endSendMessage();
458   }
459 
460   template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
sendResult(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Error Err)461   static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
462                           SequenceNumberT SeqNo, Error Err) {
463     if (Err)
464       return Err;
465     if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
466       return Err2;
467     return C.endSendMessage();
468   }
469 
470 };
471 
472 
473 // Send a response of the given wire return type (WireRetT) over the
474 // channel, with the given sequence number.
475 template <typename WireRetT, typename HandlerRetT, typename ChannelT,
476           typename FunctionIdT, typename SequenceNumberT>
respond(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Expected<HandlerRetT> ResultOrErr)477 Error respond(ChannelT &C, const FunctionIdT &ResponseId,
478               SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) {
479   return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
480     template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr));
481 }
482 
483 // Send an empty response message on the given channel to indicate that
484 // the handler ran.
485 template <typename WireRetT, typename ChannelT, typename FunctionIdT,
486           typename SequenceNumberT>
respond(ChannelT & C,const FunctionIdT & ResponseId,SequenceNumberT SeqNo,Error Err)487 Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
488               Error Err) {
489   return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
490     sendResult(C, ResponseId, SeqNo, std::move(Err));
491 }
492 
// Converts a given type to the equivalent error return type:
//   T                          -> Expected<T>
//   Expected<T>                -> Expected<T> (already an error return type)
//   void, Error, ErrorSuccess  -> Error
template <typename T> class WrappedHandlerReturn {
public:
  using Type = Expected<T>;
};

// Expected<T> is already an error return type; do not wrap it twice.
template <typename T> class WrappedHandlerReturn<Expected<T>> {
public:
  using Type = Expected<T>;
};

// Handlers returning void report their outcome as an Error.
template <> class WrappedHandlerReturn<void> {
public:
  using Type = Error;
};

template <> class WrappedHandlerReturn<Error> {
public:
  using Type = Error;
};

// ErrorSuccess is widened to plain Error.
template <> class WrappedHandlerReturn<ErrorSuccess> {
public:
  using Type = Error;
};
518 
// Traits class that strips the response function from the list of handler
// arguments. For an asynchronous handler type whose first parameter is the
// responder callback:
//   Type       - the handler's function type without that first parameter,
//                always returning Error;
//   ResultType - what the responder accepts (Expected<ResultT> or Error).
template <typename FnT> class AsyncHandlerTraits;

// Responder takes Expected<ResultT>: the RPC function produces a value.
template <typename ResultT, typename... ArgTs>
class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> {
public:
  using Type = Error(ArgTs...);
  using ResultType = Expected<ResultT>;
};

// Responder takes Error: the RPC function produces no value. This and the
// next two specializations accept handlers returning Error, ErrorSuccess or
// void respectively.
template <typename... ArgTs>
class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> {
public:
  using Type = Error(ArgTs...);
  using ResultType = Error;
};

template <typename... ArgTs>
class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> {
public:
  using Type = Error(ArgTs...);
  using ResultType = Error;
};

template <typename... ArgTs>
class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> {
public:
  using Type = Error(ArgTs...);
  using ResultType = Error;
};

// Fallback for Error-returning handlers whose responder parameter is not
// exactly one of the std::function types above (e.g. taken by const
// reference): retry with the decayed parameter type. If decaying does not
// yield one of the matched std::function forms, this specialization would
// derive from itself and the instantiation fails to compile.
template <typename ResponseHandlerT, typename... ArgTs>
class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> :
    public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type,
                                    ArgTs...)> {};
555 
// This template class provides utilities related to RPC function handlers.
// The base case applies to non-function types (the template class is
// specialized for function types) and inherits from the appropriate
// specialization for the given non-function type's call operator.
// HandlerT may be a (reference to a) functor or lambda type; its operator()
// must not be overloaded or a template, since its address is taken here.
template <typename HandlerT>
class HandlerTraits : public HandlerTraits<decltype(
                          &std::remove_reference<HandlerT>::type::operator())> {
};
564 
// Traits for handlers with a given function type. ArgTs are the declared
// parameter types of the handler; they also serve as the wire types for
// argument (de)serialization via SequenceSerialization.
template <typename RetT, typename... ArgTs>
class HandlerTraits<RetT(ArgTs...)> {
public:
  // Function type of the handler.
  using Type = RetT(ArgTs...);

  // Return type of the handler.
  using ReturnType = RetT;

  // Call the given handler with the given arguments, moving each element out
  // of Args. A void handler yields Error::success() (see the run overloads).
  template <typename HandlerT, typename... TArgTs>
  static typename WrappedHandlerReturn<RetT>::Type
  unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) {
    return unpackAndRunHelper(Handler, Args,
                              llvm::index_sequence_for<TArgTs...>());
  }

  // Call the given handler with Responder as its first argument, followed by
  // the elements of Args (moved out of the tuple).
  // NOTE(review): unpackAndRunAsyncHelper is declared to return
  // WrappedHandlerReturn<ReturnType>::Type while this returns Error;
  // presumably async handlers always return Error/ErrorSuccess/void so the two
  // agree - confirm.
  template <typename HandlerT, typename ResponderT, typename... TArgTs>
  static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder,
                                 std::tuple<TArgTs...> &Args) {
    return unpackAndRunAsyncHelper(Handler, Responder, Args,
                                   llvm::index_sequence_for<TArgTs...>());
  }

  // Overload selected when HandlerT's own return type is void: invoke the
  // handler with the moved arguments and report success.
  template <typename HandlerT>
  static typename std::enable_if<
      std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
      Error>::type
  run(HandlerT &Handler, ArgTs &&... Args) {
    Handler(std::move(Args)...);
    return Error::success();
  }

  // Overload selected when HandlerT returns a value: the handler's result is
  // returned unchanged.
  template <typename HandlerT, typename... TArgTs>
  static typename std::enable_if<
      !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
      typename HandlerTraits<HandlerT>::ReturnType>::type
  run(HandlerT &Handler, TArgTs... Args) {
    return Handler(std::move(Args)...);
  }

  // Serialize arguments to the channel, using ArgTs as the wire types.
  // The arguments are taken by value.
  template <typename ChannelT, typename... CArgTs>
  static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) {
    return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...);
  }

  // Deserialize arguments from the channel into the elements of Args, using
  // ArgTs as the wire types.
  template <typename ChannelT, typename... CArgTs>
  static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) {
    return deserializeArgsHelper(C, Args,
                                 llvm::index_sequence_for<CArgTs...>());
  }

private:
  // Expands the tuple into individual lvalue references for deserialization;
  // the index_sequence parameter only carries Indexes.
  template <typename ChannelT, typename... CArgTs, size_t... Indexes>
  static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args,
                                     llvm::index_sequence<Indexes...> _) {
    return SequenceSerialization<ChannelT, ArgTs...>::deserialize(
        C, std::get<Indexes>(Args)...);
  }

  // Expands the tuple into moved arguments and dispatches to run().
  template <typename HandlerT, typename ArgTuple, size_t... Indexes>
  static typename WrappedHandlerReturn<
      typename HandlerTraits<HandlerT>::ReturnType>::Type
  unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args,
                     llvm::index_sequence<Indexes...>) {
    return run(Handler, std::move(std::get<Indexes>(Args))...);
  }


  // As unpackAndRunHelper, but passes Responder ahead of the moved arguments.
  template <typename HandlerT, typename ResponderT, typename ArgTuple,
            size_t... Indexes>
  static typename WrappedHandlerReturn<
      typename HandlerTraits<HandlerT>::ReturnType>::Type
  unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder,
                          ArgTuple &Args,
                          llvm::index_sequence<Indexes...>) {
    return run(Handler, Responder, std::move(std::get<Indexes>(Args))...);
  }
};
649 
// Handler traits for free functions: a function pointer forwards to the
// traits of the plain function type.
template <typename RetT, typename... ArgTs>
class HandlerTraits<RetT(*)(ArgTs...)>
  : public HandlerTraits<RetT(ArgTs...)> {};

// Handler traits for class methods (especially call operators for lambdas).
// This non-const form matches, e.g., the operator() of a mutable lambda.
template <typename Class, typename RetT, typename... ArgTs>
class HandlerTraits<RetT (Class::*)(ArgTs...)>
    : public HandlerTraits<RetT(ArgTs...)> {};

// Handler traits for const class methods (especially call operators for
// lambdas). This matches the operator() of an ordinary (non-mutable) lambda.
template <typename Class, typename RetT, typename... ArgTs>
class HandlerTraits<RetT (Class::*)(ArgTs...) const>
    : public HandlerTraits<RetT(ArgTs...)> {};
665 
// Utility to peel the Expected wrapper off a response handler error type.
// Given a handler's function type, ArgType is its single parameter type and,
// for Expected<ArgT> parameters only, UnwrappedArgType is ArgT. Handlers may
// return either Error or ErrorSuccess.
template <typename HandlerT> class ResponseHandlerArg;

template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> {
public:
  using ArgType = Expected<ArgT>;
  using UnwrappedArgType = ArgT;
};

template <typename ArgT>
class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> {
public:
  using ArgType = Expected<ArgT>;
  using UnwrappedArgType = ArgT;
};

// Handlers for void results take a plain Error; there is nothing to unwrap,
// so no UnwrappedArgType is provided.
template <> class ResponseHandlerArg<Error(Error)> {
public:
  using ArgType = Error;
};

template <> class ResponseHandlerArg<ErrorSuccess(Error)> {
public:
  using ArgType = Error;
};
691 
692 // ResponseHandler represents a handler for a not-yet-received function call
693 // result.
694 template <typename ChannelT> class ResponseHandler {
695 public:
~ResponseHandler()696   virtual ~ResponseHandler() {}
697 
698   // Reads the function result off the wire and acts on it. The meaning of
699   // "act" will depend on how this method is implemented in any given
700   // ResponseHandler subclass but could, for example, mean running a
701   // user-specified handler or setting a promise value.
702   virtual Error handleResponse(ChannelT &C) = 0;
703 
704   // Abandons this outstanding result.
705   virtual void abandon() = 0;
706 
707   // Create an error instance representing an abandoned response.
createAbandonedResponseError()708   static Error createAbandonedResponseError() {
709     return make_error<ResponseAbandoned>();
710   }
711 };
712 
713 // ResponseHandler subclass for RPC functions with non-void returns.
714 template <typename ChannelT, typename FuncRetT, typename HandlerT>
715 class ResponseHandlerImpl : public ResponseHandler<ChannelT> {
716 public:
ResponseHandlerImpl(HandlerT Handler)717   ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
718 
719   // Handle the result by deserializing it from the channel then passing it
720   // to the user defined handler.
handleResponse(ChannelT & C)721   Error handleResponse(ChannelT &C) override {
722     using UnwrappedArgType = typename ResponseHandlerArg<
723         typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType;
724     UnwrappedArgType Result;
725     if (auto Err =
726             SerializationTraits<ChannelT, FuncRetT,
727                                 UnwrappedArgType>::deserialize(C, Result))
728       return Err;
729     if (auto Err = C.endReceiveMessage())
730       return Err;
731     return Handler(std::move(Result));
732   }
733 
734   // Abandon this response by calling the handler with an 'abandoned response'
735   // error.
abandon()736   void abandon() override {
737     if (auto Err = Handler(this->createAbandonedResponseError())) {
738       // Handlers should not fail when passed an abandoned response error.
739       report_fatal_error(std::move(Err));
740     }
741   }
742 
743 private:
744   HandlerT Handler;
745 };
746 
747 // ResponseHandler subclass for RPC functions with void returns.
748 template <typename ChannelT, typename HandlerT>
749 class ResponseHandlerImpl<ChannelT, void, HandlerT>
750     : public ResponseHandler<ChannelT> {
751 public:
ResponseHandlerImpl(HandlerT Handler)752   ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
753 
754   // Handle the result (no actual value, just a notification that the function
755   // has completed on the remote end) by calling the user-defined handler with
756   // Error::success().
handleResponse(ChannelT & C)757   Error handleResponse(ChannelT &C) override {
758     if (auto Err = C.endReceiveMessage())
759       return Err;
760     return Handler(Error::success());
761   }
762 
763   // Abandon this response by calling the handler with an 'abandoned response'
764   // error.
abandon()765   void abandon() override {
766     if (auto Err = Handler(this->createAbandonedResponseError())) {
767       // Handlers should not fail when passed an abandoned response error.
768       report_fatal_error(std::move(Err));
769     }
770   }
771 
772 private:
773   HandlerT Handler;
774 };
775 
776 template <typename ChannelT, typename FuncRetT, typename HandlerT>
777 class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT>
778     : public ResponseHandler<ChannelT> {
779 public:
ResponseHandlerImpl(HandlerT Handler)780   ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
781 
782   // Handle the result by deserializing it from the channel then passing it
783   // to the user defined handler.
handleResponse(ChannelT & C)784   Error handleResponse(ChannelT &C) override {
785     using HandlerArgType = typename ResponseHandlerArg<
786         typename HandlerTraits<HandlerT>::Type>::ArgType;
787     HandlerArgType Result((typename HandlerArgType::value_type()));
788 
789     if (auto Err =
790             SerializationTraits<ChannelT, Expected<FuncRetT>,
791                                 HandlerArgType>::deserialize(C, Result))
792       return Err;
793     if (auto Err = C.endReceiveMessage())
794       return Err;
795     return Handler(std::move(Result));
796   }
797 
798   // Abandon this response by calling the handler with an 'abandoned response'
799   // error.
abandon()800   void abandon() override {
801     if (auto Err = Handler(this->createAbandonedResponseError())) {
802       // Handlers should not fail when passed an abandoned response error.
803       report_fatal_error(std::move(Err));
804     }
805   }
806 
807 private:
808   HandlerT Handler;
809 };
810 
811 template <typename ChannelT, typename HandlerT>
812 class ResponseHandlerImpl<ChannelT, Error, HandlerT>
813     : public ResponseHandler<ChannelT> {
814 public:
ResponseHandlerImpl(HandlerT Handler)815   ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
816 
817   // Handle the result by deserializing it from the channel then passing it
818   // to the user defined handler.
handleResponse(ChannelT & C)819   Error handleResponse(ChannelT &C) override {
820     Error Result = Error::success();
821     if (auto Err =
822             SerializationTraits<ChannelT, Error, Error>::deserialize(C, Result))
823       return Err;
824     if (auto Err = C.endReceiveMessage())
825       return Err;
826     return Handler(std::move(Result));
827   }
828 
829   // Abandon this response by calling the handler with an 'abandoned response'
830   // error.
abandon()831   void abandon() override {
832     if (auto Err = Handler(this->createAbandonedResponseError())) {
833       // Handlers should not fail when passed an abandoned response error.
834       report_fatal_error(std::move(Err));
835     }
836   }
837 
838 private:
839   HandlerT Handler;
840 };
841 
842 // Create a ResponseHandler from a given user handler.
843 template <typename ChannelT, typename FuncRetT, typename HandlerT>
createResponseHandler(HandlerT H)844 std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) {
845   return llvm::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>(
846       std::move(H));
847 }
848 
// Helper for wrapping member functions up as functors. This is useful for
// installing methods as result handlers.
//
// The wrapper stores a reference to the object, so the object must outlive
// the wrapper (and anything the wrapper is copied into).
template <typename ClassT, typename RetT, typename... ArgTs>
class MemberFnWrapper {
public:
  using MethodT = RetT (ClassT::*)(ArgTs...);
  MemberFnWrapper(ClassT &Instance, MethodT Method)
      : Instance(Instance), Method(Method) {}

  // Invoke the wrapped method on the bound instance.
  //
  // std::forward<ArgTs> is used rather than std::move. For value and
  // rvalue-reference parameters the two are equivalent. For lvalue-reference
  // parameters (ArgTs = T&), std::move would produce an rvalue that cannot
  // bind to the method's T& parameter.
  RetT operator()(ArgTs &&... Args) {
    return (Instance.*Method)(std::forward<ArgTs>(Args)...);
  }

private:
  ClassT &Instance;
  MethodT Method;
};
865 
866 // Helper that provides a Functor for deserializing arguments.
867 template <typename... ArgTs> class ReadArgs {
868 public:
operator()869   Error operator()() { return Error::success(); }
870 };
871 
872 template <typename ArgT, typename... ArgTs>
873 class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> {
874 public:
ReadArgs(ArgT & Arg,ArgTs &...Args)875   ReadArgs(ArgT &Arg, ArgTs &... Args)
876       : ReadArgs<ArgTs...>(Args...), Arg(Arg) {}
877 
operator()878   Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) {
879     this->Arg = std::move(ArgVal);
880     return ReadArgs<ArgTs...>::operator()(ArgVals...);
881   }
882 
883 private:
884   ArgT &Arg;
885 };
886 
// Manage sequence numbers.
//
// Numbers are handed out from a counter starting at zero. Released numbers are
// recycled most-recently-released first before the counter advances. Every
// operation takes an internal mutex, so a single manager can be shared across
// threads.
template <typename SequenceNumberT> class SequenceNumberManager {
public:
  // Reset, making all sequence numbers available: drop the free list and
  // restart the counter at zero.
  void reset() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.clear();
    NextSequenceNumber = 0;
  }

  // Get the next available sequence number, preferring a previously released
  // one over a fresh one.
  SequenceNumberT getSequenceNumber() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    if (!FreeSequenceNumbers.empty()) {
      SequenceNumberT Recycled = FreeSequenceNumbers.back();
      FreeSequenceNumbers.pop_back();
      return Recycled;
    }
    return NextSequenceNumber++;
  }

  // Release a sequence number, making it available for re-use.
  void releaseSequenceNumber(SequenceNumberT SequenceNumber) {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.push_back(SequenceNumber);
  }

private:
  std::mutex SeqNoLock;
  SequenceNumberT NextSequenceNumber = 0;
  std::vector<SequenceNumberT> FreeSequenceNumbers;
};
919 
// Checks that predicate P holds for each corresponding pair of type arguments
// from T1 and T2 tuple: value is true iff P<T1_i, T2_i>::value is true for
// every position i. Only equal-length tuples are supported. Callers compare
// the lengths first.
template <template <class, class> class P, typename T1Tuple, typename T2Tuple>
class RPCArgTypeCheckHelper;

// Base case: two empty tuples trivially satisfy the predicate.
template <template <class, class> class P>
class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> {
public:
  static const bool value = true;
};

// Recursive case: check the head pair, then recurse on the tails. Both
// operands of && are instantiated regardless of the head's value, so every
// pair is checked.
template <template <class, class> class P, typename T, typename... Ts,
          typename U, typename... Us>
class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> {
  using HeadCheck = P<T, U>;
  using TailCheck =
      RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>;

public:
  static const bool value = HeadCheck::value && TailCheck::value;
};
939 
940 template <template <class, class> class P, typename T1Sig, typename T2Sig>
941 class RPCArgTypeCheck {
942 public:
943   using T1Tuple = typename FunctionArgsTuple<T1Sig>::Type;
944   using T2Tuple = typename FunctionArgsTuple<T2Sig>::Type;
945 
946   static_assert(std::tuple_size<T1Tuple>::value >=
947                     std::tuple_size<T2Tuple>::value,
948                 "Too many arguments to RPC call");
949   static_assert(std::tuple_size<T1Tuple>::value <=
950                     std::tuple_size<T2Tuple>::value,
951                 "Too few arguments to RPC call");
952 
953   static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value;
954 };
955 
956 template <typename ChannelT, typename WireT, typename ConcreteT>
957 class CanSerialize {
958 private:
959   using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
960 
961   template <typename T>
962   static std::true_type
963   check(typename std::enable_if<
964         std::is_same<decltype(T::serialize(std::declval<ChannelT &>(),
965                                            std::declval<const ConcreteT &>())),
966                      Error>::value,
967         void *>::type);
968 
969   template <typename> static std::false_type check(...);
970 
971 public:
972   static const bool value = decltype(check<S>(0))::value;
973 };
974 
975 template <typename ChannelT, typename WireT, typename ConcreteT>
976 class CanDeserialize {
977 private:
978   using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
979 
980   template <typename T>
981   static std::true_type
982   check(typename std::enable_if<
983         std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(),
984                                              std::declval<ConcreteT &>())),
985                      Error>::value,
986         void *>::type);
987 
988   template <typename> static std::false_type check(...);
989 
990 public:
991   static const bool value = decltype(check<S>(0))::value;
992 };
993 
994 /// Contains primitive utilities for defining, calling and handling calls to
995 /// remote procedures. ChannelT is a bidirectional stream conforming to the
996 /// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure
997 /// identifier type that must be serializable on ChannelT, and SequenceNumberT
998 /// is an integral type that will be used to number in-flight function calls.
999 ///
1000 /// These utilities support the construction of very primitive RPC utilities.
1001 /// Their intent is to ensure correct serialization and deserialization of
1002 /// procedure arguments, and to keep the client and server's view of the API in
1003 /// sync.
1004 template <typename ImplT, typename ChannelT, typename FunctionIdT,
1005           typename SequenceNumberT>
1006 class RPCEndpointBase {
1007 protected:
1008   class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> {
1009   public:
getName()1010     static const char *getName() { return "__orc_rpc$invalid"; }
1011   };
1012 
1013   class OrcRPCResponse : public Function<OrcRPCResponse, void()> {
1014   public:
getName()1015     static const char *getName() { return "__orc_rpc$response"; }
1016   };
1017 
1018   class OrcRPCNegotiate
1019       : public Function<OrcRPCNegotiate, FunctionIdT(std::string)> {
1020   public:
getName()1021     static const char *getName() { return "__orc_rpc$negotiate"; }
1022   };
1023 
1024   // Helper predicate for testing for the presence of SerializeTraits
1025   // serializers.
1026   template <typename WireT, typename ConcreteT>
1027   class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> {
1028   public:
1029     using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value;
1030 
1031     static_assert(value, "Missing serializer for argument (Can't serialize the "
1032                          "first template type argument of CanSerializeCheck "
1033                          "from the second)");
1034   };
1035 
1036   // Helper predicate for testing for the presence of SerializeTraits
1037   // deserializers.
1038   template <typename WireT, typename ConcreteT>
1039   class CanDeserializeCheck
1040       : detail::CanDeserialize<ChannelT, WireT, ConcreteT> {
1041   public:
1042     using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value;
1043 
1044     static_assert(value, "Missing deserializer for argument (Can't deserialize "
1045                          "the second template type argument of "
1046                          "CanDeserializeCheck from the first)");
1047   };
1048 
1049 public:
1050   /// Construct an RPC instance on a channel.
RPCEndpointBase(ChannelT & C,bool LazyAutoNegotiation)1051   RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation)
1052       : C(C), LazyAutoNegotiation(LazyAutoNegotiation) {
1053     // Hold ResponseId in a special variable, since we expect Response to be
1054     // called relatively frequently, and want to avoid the map lookup.
1055     ResponseId = FnIdAllocator.getResponseId();
1056     RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId;
1057 
1058     // Register the negotiate function id and handler.
1059     auto NegotiateId = FnIdAllocator.getNegotiateId();
1060     RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId;
1061     Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>(
1062         [this](const std::string &Name) { return handleNegotiate(Name); });
1063   }
1064 
1065 
1066   /// Negotiate a function id for Func with the other end of the channel.
1067   template <typename Func> Error negotiateFunction(bool Retry = false) {
1068     return getRemoteFunctionId<Func>(true, Retry).takeError();
1069   }
1070 
1071   /// Append a call Func, does not call send on the channel.
1072   /// The first argument specifies a user-defined handler to be run when the
1073   /// function returns. The handler should take an Expected<Func::ReturnType>,
1074   /// or an Error (if Func::ReturnType is void). The handler will be called
1075   /// with an error if the return value is abandoned due to a channel error.
1076   template <typename Func, typename HandlerT, typename... ArgTs>
appendCallAsync(HandlerT Handler,const ArgTs &...Args)1077   Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) {
1078 
1079     static_assert(
1080         detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type,
1081                                 void(ArgTs...)>::value,
1082         "");
1083 
1084     // Look up the function ID.
1085     FunctionIdT FnId;
1086     if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false))
1087       FnId = *FnIdOrErr;
1088     else {
1089       // Negotiation failed. Notify the handler then return the negotiate-failed
1090       // error.
1091       cantFail(Handler(make_error<ResponseAbandoned>()));
1092       return FnIdOrErr.takeError();
1093     }
1094 
1095     SequenceNumberT SeqNo; // initialized in locked scope below.
1096     {
1097       // Lock the pending responses map and sequence number manager.
1098       std::lock_guard<std::mutex> Lock(ResponsesMutex);
1099 
1100       // Allocate a sequence number.
1101       SeqNo = SequenceNumberMgr.getSequenceNumber();
1102       assert(!PendingResponses.count(SeqNo) &&
1103              "Sequence number already allocated");
1104 
1105       // Install the user handler.
1106       PendingResponses[SeqNo] =
1107         detail::createResponseHandler<ChannelT, typename Func::ReturnType>(
1108             std::move(Handler));
1109     }
1110 
1111     // Open the function call message.
1112     if (auto Err = C.startSendMessage(FnId, SeqNo)) {
1113       abandonPendingResponses();
1114       return Err;
1115     }
1116 
1117     // Serialize the call arguments.
1118     if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs(
1119             C, Args...)) {
1120       abandonPendingResponses();
1121       return Err;
1122     }
1123 
1124     // Close the function call messagee.
1125     if (auto Err = C.endSendMessage()) {
1126       abandonPendingResponses();
1127       return Err;
1128     }
1129 
1130     return Error::success();
1131   }
1132 
sendAppendedCalls()1133   Error sendAppendedCalls() { return C.send(); };
1134 
1135   template <typename Func, typename HandlerT, typename... ArgTs>
callAsync(HandlerT Handler,const ArgTs &...Args)1136   Error callAsync(HandlerT Handler, const ArgTs &... Args) {
1137     if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...))
1138       return Err;
1139     return C.send();
1140   }
1141 
1142   /// Handle one incoming call.
handleOne()1143   Error handleOne() {
1144     FunctionIdT FnId;
1145     SequenceNumberT SeqNo;
1146     if (auto Err = C.startReceiveMessage(FnId, SeqNo)) {
1147       abandonPendingResponses();
1148       return Err;
1149     }
1150     if (FnId == ResponseId)
1151       return handleResponse(SeqNo);
1152     auto I = Handlers.find(FnId);
1153     if (I != Handlers.end())
1154       return I->second(C, SeqNo);
1155 
1156     // else: No handler found. Report error to client?
1157     return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId,
1158                                                                      SeqNo);
1159   }
1160 
1161   /// Helper for handling setter procedures - this method returns a functor that
1162   /// sets the variables referred to by Args... to values deserialized from the
1163   /// channel.
1164   /// E.g.
1165   ///
1166   ///   typedef Function<0, bool, int> Func1;
1167   ///
1168   ///   ...
1169   ///   bool B;
1170   ///   int I;
1171   ///   if (auto Err = expect<Func1>(Channel, readArgs(B, I)))
1172   ///     /* Handle Args */ ;
1173   ///
1174   template <typename... ArgTs>
readArgs(ArgTs &...Args)1175   static detail::ReadArgs<ArgTs...> readArgs(ArgTs &... Args) {
1176     return detail::ReadArgs<ArgTs...>(Args...);
1177   }
1178 
1179   /// Abandon all outstanding result handlers.
1180   ///
1181   /// This will call all currently registered result handlers to receive an
1182   /// "abandoned" error as their argument. This is used internally by the RPC
1183   /// in error situations, but can also be called directly by clients who are
1184   /// disconnecting from the remote and don't or can't expect responses to their
1185   /// outstanding calls. (Especially for outstanding blocking calls, calling
1186   /// this function may be necessary to avoid dead threads).
abandonPendingResponses()1187   void abandonPendingResponses() {
1188     // Lock the pending responses map and sequence number manager.
1189     std::lock_guard<std::mutex> Lock(ResponsesMutex);
1190 
1191     for (auto &KV : PendingResponses)
1192       KV.second->abandon();
1193     PendingResponses.clear();
1194     SequenceNumberMgr.reset();
1195   }
1196 
1197   /// Remove the handler for the given function.
1198   /// A handler must currently be registered for this function.
1199   template <typename Func>
removeHandler()1200   void removeHandler() {
1201     auto IdItr = LocalFunctionIds.find(Func::getPrototype());
1202     assert(IdItr != LocalFunctionIds.end() &&
1203            "Function does not have a registered handler");
1204     auto HandlerItr = Handlers.find(IdItr->second);
1205     assert(HandlerItr != Handlers.end() &&
1206            "Function does not have a registered handler");
1207     Handlers.erase(HandlerItr);
1208   }
1209 
1210   /// Clear all handlers.
clearHandlers()1211   void clearHandlers() {
1212     Handlers.clear();
1213   }
1214 
1215 protected:
1216 
getInvalidFunctionId()1217   FunctionIdT getInvalidFunctionId() const {
1218     return FnIdAllocator.getInvalidId();
1219   }
1220 
1221   /// Add the given handler to the handler map and make it available for
1222   /// autonegotiation and execution.
1223   template <typename Func, typename HandlerT>
addHandlerImpl(HandlerT Handler)1224   void addHandlerImpl(HandlerT Handler) {
1225 
1226     static_assert(detail::RPCArgTypeCheck<
1227                       CanDeserializeCheck, typename Func::Type,
1228                       typename detail::HandlerTraits<HandlerT>::Type>::value,
1229                   "");
1230 
1231     FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
1232     LocalFunctionIds[Func::getPrototype()] = NewFnId;
1233     Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler));
1234   }
1235 
1236   template <typename Func, typename HandlerT>
addAsyncHandlerImpl(HandlerT Handler)1237   void addAsyncHandlerImpl(HandlerT Handler) {
1238 
1239     static_assert(detail::RPCArgTypeCheck<
1240                       CanDeserializeCheck, typename Func::Type,
1241                       typename detail::AsyncHandlerTraits<
1242                         typename detail::HandlerTraits<HandlerT>::Type
1243                       >::Type>::value,
1244                   "");
1245 
1246     FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
1247     LocalFunctionIds[Func::getPrototype()] = NewFnId;
1248     Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler));
1249   }
1250 
handleResponse(SequenceNumberT SeqNo)1251   Error handleResponse(SequenceNumberT SeqNo) {
1252     using Handler = typename decltype(PendingResponses)::mapped_type;
1253     Handler PRHandler;
1254 
1255     {
1256       // Lock the pending responses map and sequence number manager.
1257       std::unique_lock<std::mutex> Lock(ResponsesMutex);
1258       auto I = PendingResponses.find(SeqNo);
1259 
1260       if (I != PendingResponses.end()) {
1261         PRHandler = std::move(I->second);
1262         PendingResponses.erase(I);
1263         SequenceNumberMgr.releaseSequenceNumber(SeqNo);
1264       } else {
1265         // Unlock the pending results map to prevent recursive lock.
1266         Lock.unlock();
1267         abandonPendingResponses();
1268         return make_error<
1269                  InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo);
1270       }
1271     }
1272 
1273     assert(PRHandler &&
1274            "If we didn't find a response handler we should have bailed out");
1275 
1276     if (auto Err = PRHandler->handleResponse(C)) {
1277       abandonPendingResponses();
1278       return Err;
1279     }
1280 
1281     return Error::success();
1282   }
1283 
handleNegotiate(const std::string & Name)1284   FunctionIdT handleNegotiate(const std::string &Name) {
1285     auto I = LocalFunctionIds.find(Name);
1286     if (I == LocalFunctionIds.end())
1287       return getInvalidFunctionId();
1288     return I->second;
1289   }
1290 
1291   // Find the remote FunctionId for the given function.
1292   template <typename Func>
getRemoteFunctionId(bool NegotiateIfNotInMap,bool NegotiateIfInvalid)1293   Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap,
1294                                             bool NegotiateIfInvalid) {
1295     bool DoNegotiate;
1296 
1297     // Check if we already have a function id...
1298     auto I = RemoteFunctionIds.find(Func::getPrototype());
1299     if (I != RemoteFunctionIds.end()) {
1300       // If it's valid there's nothing left to do.
1301       if (I->second != getInvalidFunctionId())
1302         return I->second;
1303       DoNegotiate = NegotiateIfInvalid;
1304     } else
1305       DoNegotiate = NegotiateIfNotInMap;
1306 
1307     // We don't have a function id for Func yet, but we're allowed to try to
1308     // negotiate one.
1309     if (DoNegotiate) {
1310       auto &Impl = static_cast<ImplT &>(*this);
1311       if (auto RemoteIdOrErr =
1312           Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) {
1313         RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
1314         if (*RemoteIdOrErr == getInvalidFunctionId())
1315           return make_error<CouldNotNegotiate>(Func::getPrototype());
1316         return *RemoteIdOrErr;
1317       } else
1318         return RemoteIdOrErr.takeError();
1319     }
1320 
1321     // No key was available in the map and we weren't allowed to try to
1322     // negotiate one, so return an unknown function error.
1323     return make_error<CouldNotNegotiate>(Func::getPrototype());
1324   }
1325 
1326   using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>;
1327 
1328   // Wrap the given user handler in the necessary argument-deserialization code,
1329   // result-serialization code, and call to the launch policy (if present).
1330   template <typename Func, typename HandlerT>
wrapHandler(HandlerT Handler)1331   WrappedHandlerFn wrapHandler(HandlerT Handler) {
1332     return [this, Handler](ChannelT &Channel,
1333                            SequenceNumberT SeqNo) mutable -> Error {
1334       // Start by deserializing the arguments.
1335       using ArgsTuple =
1336           typename detail::FunctionArgsTuple<
1337             typename detail::HandlerTraits<HandlerT>::Type>::Type;
1338       auto Args = std::make_shared<ArgsTuple>();
1339 
1340       if (auto Err =
1341               detail::HandlerTraits<typename Func::Type>::deserializeArgs(
1342                   Channel, *Args))
1343         return Err;
1344 
1345       // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
1346       // for RPCArgs. Void cast RPCArgs to work around this for now.
1347       // FIXME: Remove this workaround once we can assume a working GCC version.
1348       (void)Args;
1349 
1350       // End receieve message, unlocking the channel for reading.
1351       if (auto Err = Channel.endReceiveMessage())
1352         return Err;
1353 
1354       using HTraits = detail::HandlerTraits<HandlerT>;
1355       using FuncReturn = typename Func::ReturnType;
1356       return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo,
1357                                          HTraits::unpackAndRun(Handler, *Args));
1358     };
1359   }
1360 
1361   // Wrap the given user handler in the necessary argument-deserialization code,
1362   // result-serialization code, and call to the launch policy (if present).
1363   template <typename Func, typename HandlerT>
wrapAsyncHandler(HandlerT Handler)1364   WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) {
1365     return [this, Handler](ChannelT &Channel,
1366                            SequenceNumberT SeqNo) mutable -> Error {
1367       // Start by deserializing the arguments.
1368       using AHTraits = detail::AsyncHandlerTraits<
1369                          typename detail::HandlerTraits<HandlerT>::Type>;
1370       using ArgsTuple =
1371           typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type;
1372       auto Args = std::make_shared<ArgsTuple>();
1373 
1374       if (auto Err =
1375               detail::HandlerTraits<typename Func::Type>::deserializeArgs(
1376                   Channel, *Args))
1377         return Err;
1378 
1379       // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
1380       // for RPCArgs. Void cast RPCArgs to work around this for now.
1381       // FIXME: Remove this workaround once we can assume a working GCC version.
1382       (void)Args;
1383 
1384       // End receieve message, unlocking the channel for reading.
1385       if (auto Err = Channel.endReceiveMessage())
1386         return Err;
1387 
1388       using HTraits = detail::HandlerTraits<HandlerT>;
1389       using FuncReturn = typename Func::ReturnType;
1390       auto Responder =
1391         [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error {
1392           return detail::respond<FuncReturn>(C, ResponseId, SeqNo,
1393                                              std::move(RetVal));
1394         };
1395 
1396       return HTraits::unpackAndRunAsync(Handler, Responder, *Args);
1397     };
1398   }
1399 
1400   ChannelT &C;
1401 
1402   bool LazyAutoNegotiation;
1403 
1404   RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator;
1405 
1406   FunctionIdT ResponseId;
1407   std::map<std::string, FunctionIdT> LocalFunctionIds;
1408   std::map<const char *, FunctionIdT> RemoteFunctionIds;
1409 
1410   std::map<FunctionIdT, WrappedHandlerFn> Handlers;
1411 
1412   std::mutex ResponsesMutex;
1413   detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr;
1414   std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>>
1415       PendingResponses;
1416 };
1417 
1418 } // end namespace detail
1419 
1420 template <typename ChannelT, typename FunctionIdT = uint32_t,
1421           typename SequenceNumberT = uint32_t>
1422 class MultiThreadedRPCEndpoint
1423     : public detail::RPCEndpointBase<
1424           MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1425           ChannelT, FunctionIdT, SequenceNumberT> {
1426 private:
1427   using BaseClass =
1428       detail::RPCEndpointBase<
1429         MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1430         ChannelT, FunctionIdT, SequenceNumberT>;
1431 
1432 public:
MultiThreadedRPCEndpoint(ChannelT & C,bool LazyAutoNegotiation)1433   MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1434       : BaseClass(C, LazyAutoNegotiation) {}
1435 
1436   /// Add a handler for the given RPC function.
1437   /// This installs the given handler functor for the given RPC Function, and
1438   /// makes the RPC function available for negotiation/calling from the remote.
1439   template <typename Func, typename HandlerT>
addHandler(HandlerT Handler)1440   void addHandler(HandlerT Handler) {
1441     return this->template addHandlerImpl<Func>(std::move(Handler));
1442   }
1443 
1444   /// Add a class-method as a handler.
1445   template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
addHandler(ClassT & Object,RetT (ClassT::* Method)(ArgTs...))1446   void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1447     addHandler<Func>(
1448       detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1449   }
1450 
1451   template <typename Func, typename HandlerT>
addAsyncHandler(HandlerT Handler)1452   void addAsyncHandler(HandlerT Handler) {
1453     return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1454   }
1455 
1456   /// Add a class-method as a handler.
1457   template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
addAsyncHandler(ClassT & Object,RetT (ClassT::* Method)(ArgTs...))1458   void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1459     addAsyncHandler<Func>(
1460       detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1461   }
1462 
1463   /// Return type for non-blocking call primitives.
1464   template <typename Func>
1465   using NonBlockingCallResult = typename detail::ResultTraits<
1466       typename Func::ReturnType>::ReturnFutureType;
1467 
1468   /// Call Func on Channel C. Does not block, does not call send. Returns a pair
1469   /// of a future result and the sequence number assigned to the result.
1470   ///
1471   /// This utility function is primarily used for single-threaded mode support,
1472   /// where the sequence number can be used to wait for the corresponding
1473   /// result. In multi-threaded mode the appendCallNB method, which does not
1474   /// return the sequence numeber, should be preferred.
1475   template <typename Func, typename... ArgTs>
appendCallNB(const ArgTs &...Args)1476   Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &... Args) {
1477     using RTraits = detail::ResultTraits<typename Func::ReturnType>;
1478     using ErrorReturn = typename RTraits::ErrorReturnType;
1479     using ErrorReturnPromise = typename RTraits::ReturnPromiseType;
1480 
1481     // FIXME: Stack allocate and move this into the handler once LLVM builds
1482     //        with C++14.
1483     auto Promise = std::make_shared<ErrorReturnPromise>();
1484     auto FutureResult = Promise->get_future();
1485 
1486     if (auto Err = this->template appendCallAsync<Func>(
1487             [Promise](ErrorReturn RetOrErr) {
1488               Promise->set_value(std::move(RetOrErr));
1489               return Error::success();
1490             },
1491             Args...)) {
1492       RTraits::consumeAbandoned(FutureResult.get());
1493       return std::move(Err);
1494     }
1495     return std::move(FutureResult);
1496   }
1497 
1498   /// The same as appendCallNBWithSeq, except that it calls C.send() to
1499   /// flush the channel after serializing the call.
1500   template <typename Func, typename... ArgTs>
callNB(const ArgTs &...Args)1501   Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &... Args) {
1502     auto Result = appendCallNB<Func>(Args...);
1503     if (!Result)
1504       return Result;
1505     if (auto Err = this->C.send()) {
1506       this->abandonPendingResponses();
1507       detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1508           std::move(Result->get()));
1509       return std::move(Err);
1510     }
1511     return Result;
1512   }
1513 
1514   /// Call Func on Channel C. Blocks waiting for a result. Returns an Error
1515   /// for void functions or an Expected<T> for functions returning a T.
1516   ///
1517   /// This function is for use in threaded code where another thread is
1518   /// handling responses and incoming calls.
1519   template <typename Func, typename... ArgTs,
1520             typename AltRetT = typename Func::ReturnType>
1521   typename detail::ResultTraits<AltRetT>::ErrorReturnType
callB(const ArgTs &...Args)1522   callB(const ArgTs &... Args) {
1523     if (auto FutureResOrErr = callNB<Func>(Args...))
1524       return FutureResOrErr->get();
1525     else
1526       return FutureResOrErr.takeError();
1527   }
1528 
1529   /// Handle incoming RPC calls.
handlerLoop()1530   Error handlerLoop() {
1531     while (true)
1532       if (auto Err = this->handleOne())
1533         return Err;
1534     return Error::success();
1535   }
1536 };
1537 
1538 template <typename ChannelT, typename FunctionIdT = uint32_t,
1539           typename SequenceNumberT = uint32_t>
1540 class SingleThreadedRPCEndpoint
1541     : public detail::RPCEndpointBase<
1542           SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1543           ChannelT, FunctionIdT, SequenceNumberT> {
1544 private:
1545   using BaseClass =
1546       detail::RPCEndpointBase<
1547         SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1548         ChannelT, FunctionIdT, SequenceNumberT>;
1549 
1550 public:
SingleThreadedRPCEndpoint(ChannelT & C,bool LazyAutoNegotiation)1551   SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1552       : BaseClass(C, LazyAutoNegotiation) {}
1553 
1554   template <typename Func, typename HandlerT>
addHandler(HandlerT Handler)1555   void addHandler(HandlerT Handler) {
1556     return this->template addHandlerImpl<Func>(std::move(Handler));
1557   }
1558 
1559   template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
addHandler(ClassT & Object,RetT (ClassT::* Method)(ArgTs...))1560   void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1561     addHandler<Func>(
1562         detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1563   }
1564 
1565   template <typename Func, typename HandlerT>
addAsyncHandler(HandlerT Handler)1566   void addAsyncHandler(HandlerT Handler) {
1567     return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1568   }
1569 
1570   /// Add a class-method as a handler.
1571   template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
addAsyncHandler(ClassT & Object,RetT (ClassT::* Method)(ArgTs...))1572   void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1573     addAsyncHandler<Func>(
1574       detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1575   }
1576 
1577   template <typename Func, typename... ArgTs,
1578             typename AltRetT = typename Func::ReturnType>
1579   typename detail::ResultTraits<AltRetT>::ErrorReturnType
callB(const ArgTs &...Args)1580   callB(const ArgTs &... Args) {
1581     bool ReceivedResponse = false;
1582     using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType;
1583     auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue();
1584 
1585     // We have to 'Check' result (which we know is in a success state at this
1586     // point) so that it can be overwritten in the async handler.
1587     (void)!!Result;
1588 
1589     if (auto Err = this->template appendCallAsync<Func>(
1590             [&](ResultType R) {
1591               Result = std::move(R);
1592               ReceivedResponse = true;
1593               return Error::success();
1594             },
1595             Args...)) {
1596       detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1597           std::move(Result));
1598       return std::move(Err);
1599     }
1600 
1601     while (!ReceivedResponse) {
1602       if (auto Err = this->handleOne()) {
1603         detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1604             std::move(Result));
1605         return std::move(Err);
1606       }
1607     }
1608 
1609     return Result;
1610   }
1611 };
1612 
1613 /// Asynchronous dispatch for a function on an RPC endpoint.
1614 template <typename RPCClass, typename Func>
1615 class RPCAsyncDispatch {
1616 public:
RPCAsyncDispatch(RPCClass & Endpoint)1617   RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {}
1618 
1619   template <typename HandlerT, typename... ArgTs>
operator()1620   Error operator()(HandlerT Handler, const ArgTs &... Args) const {
1621     return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...);
1622   }
1623 
1624 private:
1625   RPCClass &Endpoint;
1626 };
1627 
1628 /// Construct an asynchronous dispatcher from an RPC endpoint and a Func.
1629 template <typename Func, typename RPCEndpointT>
rpcAsyncDispatch(RPCEndpointT & Endpoint)1630 RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) {
1631   return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint);
1632 }
1633 
1634 /// Allows a set of asynchrounous calls to be dispatched, and then
1635 ///        waited on as a group.
1636 class ParallelCallGroup {
1637 public:
1638 
1639   ParallelCallGroup() = default;
1640   ParallelCallGroup(const ParallelCallGroup &) = delete;
1641   ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;
1642 
1643   /// Make as asynchronous call.
1644   template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs>
call(const AsyncDispatcher & AsyncDispatch,HandlerT Handler,const ArgTs &...Args)1645   Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler,
1646              const ArgTs &... Args) {
1647     // Increment the count of outstanding calls. This has to happen before
1648     // we invoke the call, as the handler may (depending on scheduling)
1649     // be run immediately on another thread, and we don't want the decrement
1650     // in the wrapped handler below to run before the increment.
1651     {
1652       std::unique_lock<std::mutex> Lock(M);
1653       ++NumOutstandingCalls;
1654     }
1655 
1656     // Wrap the user handler in a lambda that will decrement the
1657     // outstanding calls count, then poke the condition variable.
1658     using ArgType = typename detail::ResponseHandlerArg<
1659         typename detail::HandlerTraits<HandlerT>::Type>::ArgType;
1660     // FIXME: Move handler into wrapped handler once we have C++14.
1661     auto WrappedHandler = [this, Handler](ArgType Arg) {
1662       auto Err = Handler(std::move(Arg));
1663       std::unique_lock<std::mutex> Lock(M);
1664       --NumOutstandingCalls;
1665       CV.notify_all();
1666       return Err;
1667     };
1668 
1669     return AsyncDispatch(std::move(WrappedHandler), Args...);
1670   }
1671 
1672   /// Blocks until all calls have been completed and their return value
1673   ///        handlers run.
wait()1674   void wait() {
1675     std::unique_lock<std::mutex> Lock(M);
1676     while (NumOutstandingCalls > 0)
1677       CV.wait(Lock);
1678   }
1679 
1680 private:
1681   std::mutex M;
1682   std::condition_variable CV;
1683   uint32_t NumOutstandingCalls = 0;
1684 };
1685 
1686 /// Convenience class for grouping RPC Functions into APIs that can be
1687 ///        negotiated as a block.
1688 ///
1689 template <typename... Funcs>
1690 class APICalls {
1691 public:
1692 
1693   /// Test whether this API contains Function F.
1694   template <typename F>
1695   class Contains {
1696   public:
1697     static const bool value = false;
1698   };
1699 
1700   /// Negotiate all functions in this API.
1701   template <typename RPCEndpoint>
negotiate(RPCEndpoint & R)1702   static Error negotiate(RPCEndpoint &R) {
1703     return Error::success();
1704   }
1705 };
1706 
1707 template <typename Func, typename... Funcs>
1708 class APICalls<Func, Funcs...> {
1709 public:
1710 
1711   template <typename F>
1712   class Contains {
1713   public:
1714     static const bool value = std::is_same<F, Func>::value |
1715                               APICalls<Funcs...>::template Contains<F>::value;
1716   };
1717 
1718   template <typename RPCEndpoint>
negotiate(RPCEndpoint & R)1719   static Error negotiate(RPCEndpoint &R) {
1720     if (auto Err = R.template negotiateFunction<Func>())
1721       return Err;
1722     return APICalls<Funcs...>::negotiate(R);
1723   }
1724 
1725 };
1726 
1727 template <typename... InnerFuncs, typename... Funcs>
1728 class APICalls<APICalls<InnerFuncs...>, Funcs...> {
1729 public:
1730 
1731   template <typename F>
1732   class Contains {
1733   public:
1734     static const bool value =
1735       APICalls<InnerFuncs...>::template Contains<F>::value |
1736       APICalls<Funcs...>::template Contains<F>::value;
1737   };
1738 
1739   template <typename RPCEndpoint>
negotiate(RPCEndpoint & R)1740   static Error negotiate(RPCEndpoint &R) {
1741     if (auto Err = APICalls<InnerFuncs...>::negotiate(R))
1742       return Err;
1743     return APICalls<Funcs...>::negotiate(R);
1744   }
1745 
1746 };
1747 
1748 } // end namespace rpc
1749 } // end namespace orc
1750 } // end namespace llvm
1751 
1752 #endif
1753