// Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.

use std::pin::Pin;
use std::ptr;
use std::sync::Arc;
use std::time::Duration;

use crate::grpc_sys;
use futures::ready;
use futures::sink::Sink;
use futures::stream::Stream;
use futures::task::{Context, Poll};
use parking_lot::Mutex;
use std::future::Future;

use super::{ShareCall, ShareCallHolder, SinkBase, WriteFlags};
use crate::buf::GrpcSlice;
use crate::call::{check_run, Call, MessageReader, Method};
use crate::channel::Channel;
use crate::codec::{DeserializeFn, SerializeFn};
use crate::error::{Error, Result};
use crate::metadata::Metadata;
use crate::task::{BatchFuture, BatchType};

/// Update the flag bit in `res`.
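///
/// A minimal sketch of what this helper does (it mirrors the unit test at the
/// bottom of this file):
///
/// ```ignore
/// let mut flags = 0b0110u32;
/// change_flag(&mut flags, 0b1000, true); // set the bit
/// assert_eq!(flags, 0b1110);
/// change_flag(&mut flags, 0b0100, false); // clear the bit
/// assert_eq!(flags, 0b1010);
/// ```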
#[inline]
pub fn change_flag(res: &mut u32, flag: u32, set: bool) {
    if set {
        *res |= flag;
    } else {
        *res &= !flag;
    }
}

/// Options for calls made by a client.
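///
/// # Examples
///
/// A minimal sketch of chaining the builder-style setters (assuming the crate
/// is consumed as `grpcio`, which re-exports `CallOption`):
///
/// ```ignore
/// use std::time::Duration;
/// use grpcio::CallOption;
///
/// let opt = CallOption::default()
///     .wait_for_ready(true)
///     .timeout(Duration::from_secs(3));
/// assert_eq!(opt.get_timeout(), Some(Duration::from_secs(3)));
/// ```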
#[derive(Clone, Default)]
pub struct CallOption {
    timeout: Option<Duration>,
    write_flags: WriteFlags,
    call_flags: u32,
    headers: Option<Metadata>,
}

impl CallOption {
    /// Signal that the call is idempotent.
    pub fn idempotent(mut self, is_idempotent: bool) -> CallOption {
        change_flag(
            &mut self.call_flags,
            grpc_sys::GRPC_INITIAL_METADATA_IDEMPOTENT_REQUEST,
            is_idempotent,
        );
        self
    }

    /// Signal that the call should not return UNAVAILABLE before it has started.
    pub fn wait_for_ready(mut self, wait_for_ready: bool) -> CallOption {
        change_flag(
            &mut self.call_flags,
            grpc_sys::GRPC_INITIAL_METADATA_WAIT_FOR_READY,
            wait_for_ready,
        );
        self
    }

    /// Signal that the call is cacheable. gRPC is free to use the GET verb.
    pub fn cacheable(mut self, cacheable: bool) -> CallOption {
        change_flag(
            &mut self.call_flags,
            grpc_sys::GRPC_INITIAL_METADATA_CACHEABLE_REQUEST,
            cacheable,
        );
        self
    }

    /// Set write flags.
    pub fn write_flags(mut self, write_flags: WriteFlags) -> CallOption {
        self.write_flags = write_flags;
        self
    }

    /// Set a timeout.
    pub fn timeout(mut self, timeout: Duration) -> CallOption {
        self.timeout = Some(timeout);
        self
    }

    /// Get the timeout.
    pub fn get_timeout(&self) -> Option<Duration> {
        self.timeout
    }

    /// Set the headers to be sent with the call.
    pub fn headers(mut self, meta: Metadata) -> CallOption {
        self.headers = Some(meta);
        self
    }

    /// Get the headers to be sent with the call.
    pub fn get_headers(&self) -> Option<&Metadata> {
        self.headers.as_ref()
    }
}

impl Call {
    /// Issue a unary request and return a receiver that resolves to the response.
    pub fn unary_async<Req, Resp>(
        channel: &Channel,
        method: &Method<Req, Resp>,
        req: &Req,
        mut opt: CallOption,
    ) -> Result<ClientUnaryReceiver<Resp>> {
        let call = channel.create_call(method, &opt)?;
        let mut payload = GrpcSlice::default();
        (method.req_ser())(req, &mut payload);
        let cq_f = check_run(BatchType::CheckRead, |ctx, tag| unsafe {
            grpc_sys::grpcwrap_call_start_unary(
                call.call,
                ctx,
                payload.as_mut_ptr(),
                opt.write_flags.flags,
                opt.headers
                    .as_mut()
                    .map_or_else(ptr::null_mut, |c| c as *mut _ as _),
                opt.call_flags,
                tag,
            )
        });
        Ok(ClientUnaryReceiver::new(call, cq_f, method.resp_de()))
    }

    /// Start a client streaming call and return the request sink together with
    /// the response receiver.
    pub fn client_streaming<Req, Resp>(
        channel: &Channel,
        method: &Method<Req, Resp>,
        mut opt: CallOption,
    ) -> Result<(ClientCStreamSender<Req>, ClientCStreamReceiver<Resp>)> {
        let call = channel.create_call(method, &opt)?;
        let cq_f = check_run(BatchType::CheckRead, |ctx, tag| unsafe {
            grpc_sys::grpcwrap_call_start_client_streaming(
                call.call,
                ctx,
                opt.headers
                    .as_mut()
                    .map_or_else(ptr::null_mut, |c| c as *mut _ as _),
                opt.call_flags,
                tag,
            )
        });

        let share_call = Arc::new(Mutex::new(ShareCall::new(call, cq_f)));
        let sink = ClientCStreamSender::new(share_call.clone(), method.req_ser());
        let recv = ClientCStreamReceiver {
            call: share_call,
            resp_de: method.resp_de(),
            finished: false,
        };
        Ok((sink, recv))
    }

    /// Issue a server streaming request and return a stream of responses.
    pub fn server_streaming<Req, Resp>(
        channel: &Channel,
        method: &Method<Req, Resp>,
        req: &Req,
        mut opt: CallOption,
    ) -> Result<ClientSStreamReceiver<Resp>> {
        let call = channel.create_call(method, &opt)?;
        let mut payload = GrpcSlice::default();
        (method.req_ser())(req, &mut payload);
        let cq_f = check_run(BatchType::Finish, |ctx, tag| unsafe {
            grpc_sys::grpcwrap_call_start_server_streaming(
                call.call,
                ctx,
                payload.as_mut_ptr(),
                opt.write_flags.flags,
                opt.headers
                    .as_mut()
                    .map_or_else(ptr::null_mut, |c| c as *mut _ as _),
                opt.call_flags,
                tag,
            )
        });

        // TODO: handle header.
        check_run(BatchType::Finish, |ctx, tag| unsafe {
            grpc_sys::grpcwrap_call_recv_initial_metadata(call.call, ctx, tag)
        });

        Ok(ClientSStreamReceiver::new(call, cq_f, method.resp_de()))
    }

    /// Start a duplex streaming call and return the request sink together with
    /// the response stream.
    pub fn duplex_streaming<Req, Resp>(
        channel: &Channel,
        method: &Method<Req, Resp>,
        mut opt: CallOption,
    ) -> Result<(ClientDuplexSender<Req>, ClientDuplexReceiver<Resp>)> {
        let call = channel.create_call(method, &opt)?;
        let cq_f = check_run(BatchType::Finish, |ctx, tag| unsafe {
            grpc_sys::grpcwrap_call_start_duplex_streaming(
                call.call,
                ctx,
                opt.headers
                    .as_mut()
                    .map_or_else(ptr::null_mut, |c| c as *mut _ as _),
                opt.call_flags,
                tag,
            )
        });

        // TODO: handle header.
        check_run(BatchType::Finish, |ctx, tag| unsafe {
            grpc_sys::grpcwrap_call_recv_initial_metadata(call.call, ctx, tag)
        });

        let share_call = Arc::new(Mutex::new(ShareCall::new(call, cq_f)));
        let sink = ClientDuplexSender::new(share_call.clone(), method.req_ser());
        let recv = ClientDuplexReceiver::new(share_call, method.resp_de());
        Ok((sink, recv))
    }
}

/// A receiver for a unary request.
///
/// The future resolves once the response has been received.
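///
/// A usage sketch; `greet_async` stands in for any generated async client
/// method that returns a `ClientUnaryReceiver` (the name is illustrative):
///
/// ```ignore
/// let receiver = client.greet_async(&request)?;
/// let response = receiver.await?;
/// ```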
#[must_use = "if unused the ClientUnaryReceiver may immediately cancel the RPC"]
pub struct ClientUnaryReceiver<T> {
    call: Call,
    resp_f: BatchFuture,
    resp_de: DeserializeFn<T>,
}

impl<T> ClientUnaryReceiver<T> {
    fn new(call: Call, resp_f: BatchFuture, resp_de: DeserializeFn<T>) -> ClientUnaryReceiver<T> {
        ClientUnaryReceiver {
            call,
            resp_f,
            resp_de,
        }
    }

    /// Cancel the call.
    #[inline]
    pub fn cancel(&mut self) {
        self.call.cancel()
    }

    /// Deserialize a response message with the configured deserializer.
    #[inline]
    pub fn resp_de(&self, reader: MessageReader) -> Result<T> {
        (self.resp_de)(reader)
    }
}

impl<T> Future for ClientUnaryReceiver<T> {
    type Output = Result<T>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<T>> {
        let data = ready!(Pin::new(&mut self.resp_f).poll(cx)?);
        let t = self.resp_de(data.unwrap())?;
        Poll::Ready(Ok(t))
    }
}

/// A receiver for a client streaming call.
///
/// If the corresponding sink has been dropped or cancelled, polling this
/// receiver will yield a [`RpcFailure`] error with the [`Cancelled`] status.
///
/// [`RpcFailure`]: ./enum.Error.html#variant.RpcFailure
/// [`Cancelled`]: ./enum.RpcStatusCode.html#variant.Cancelled
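///
/// A usage sketch; `record_route` stands in for any generated client-streaming
/// method returning this sender/receiver pair, and `point` is a placeholder
/// request message (both names are illustrative):
///
/// ```ignore
/// use futures::SinkExt;
///
/// let (mut sink, receiver) = client.record_route()?;
/// sink.send((point, WriteFlags::default())).await?;
/// sink.close().await?;
/// let summary = receiver.await?;
/// ```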
#[must_use = "if unused the ClientCStreamReceiver may immediately cancel the RPC"]
pub struct ClientCStreamReceiver<T> {
    call: Arc<Mutex<ShareCall>>,
    resp_de: DeserializeFn<T>,
    finished: bool,
}

impl<T> ClientCStreamReceiver<T> {
    /// Cancel the call.
    pub fn cancel(&mut self) {
        let lock = self.call.lock();
        lock.call.cancel()
    }

    /// Deserialize a response message with the configured deserializer.
    #[inline]
    pub fn resp_de(&self, reader: MessageReader) -> Result<T> {
        (self.resp_de)(reader)
    }
}

impl<T> Drop for ClientCStreamReceiver<T> {
    /// The corresponding RPC will be canceled if the receiver is dropped
    /// before it finishes.
    fn drop(&mut self) {
        if !self.finished {
            self.cancel();
        }
    }
}

impl<T> Future for ClientCStreamReceiver<T> {
    type Output = Result<T>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<T>> {
        let data = {
            let mut call = self.call.lock();
            ready!(call.poll_finish(cx)?)
        };
        let t = (self.resp_de)(data.unwrap())?;
        self.finished = true;
        Poll::Ready(Ok(t))
    }
}

/// A sink for client streaming calls and duplex streaming calls.
///
/// To close the sink properly, you should call [`close`] before dropping it.
///
/// [`close`]: #method.close
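///
/// A minimal sketch of the close contract, using `futures::SinkExt`; `sink` is
/// either sender half returned by a streaming call and `req` is a placeholder
/// request message:
///
/// ```ignore
/// use futures::SinkExt;
///
/// sink.send((req, WriteFlags::default())).await?;
/// sink.close().await?; // flushes pending messages and half-closes the call
/// ```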
#[must_use = "if unused the StreamingCallSink may immediately cancel the RPC"]
pub struct StreamingCallSink<Req> {
    call: Arc<Mutex<ShareCall>>,
    sink_base: SinkBase,
    close_f: Option<BatchFuture>,
    req_ser: SerializeFn<Req>,
}

impl<Req> StreamingCallSink<Req> {
    fn new(call: Arc<Mutex<ShareCall>>, req_ser: SerializeFn<Req>) -> StreamingCallSink<Req> {
        StreamingCallSink {
            call,
            sink_base: SinkBase::new(false),
            close_f: None,
            req_ser,
        }
    }

    /// By default, every message is sent with its configured buffer hint. When
    /// `enhance_batch` is enabled, messages are batched together as much as
    /// possible, according to the following rules:
    /// - All messages except the last one are sent with `buffer_hint` set to true.
    /// - The last message is also sent with `buffer_hint` set to true, unless any
    ///   message was offered with its buffer hint set to false.
    ///
    /// Whether `enhance_batch` is enabled or not, it is recommended to follow the
    /// `Sink` contract and call `poll_flush` to ensure messages are handled by the
    /// gRPC C core.
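    ///
    /// A minimal usage sketch; `sink` is any `StreamingCallSink` (for example a
    /// `ClientDuplexSender` returned by a generated streaming method) and
    /// `req1`/`req2` are placeholder request messages:
    ///
    /// ```ignore
    /// use futures::SinkExt;
    ///
    /// sink.enhance_batch(true);
    /// sink.send((req1, WriteFlags::default())).await?;
    /// sink.send((req2, WriteFlags::default())).await?;
    /// sink.flush().await?; // make sure batched messages reach the gRPC C core
    /// ```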
    pub fn enhance_batch(&mut self, flag: bool) {
        self.sink_base.enhance_buffer_strategy = flag;
    }

    /// Cancel the RPC associated with this sink.
    pub fn cancel(&mut self) {
        let call = self.call.lock();
        call.call.cancel()
    }
}

impl<P> Drop for StreamingCallSink<P> {
    /// The corresponding RPC will be canceled if [`close`] was not called
    /// before the sink is dropped.
    ///
    /// [`close`]: #method.close
    fn drop(&mut self) {
        if self.close_f.is_none() {
            self.cancel();
        }
    }
}

impl<Req> Sink<(Req, WriteFlags)> for StreamingCallSink<Req> {
    type Error = Error;

    #[inline]
    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
        Pin::new(&mut self.sink_base).poll_ready(cx)
    }

    #[inline]
    fn start_send(mut self: Pin<&mut Self>, (msg, flags): (Req, WriteFlags)) -> Result<()> {
        {
            let mut call = self.call.lock();
            call.check_alive()?;
        }
        let t = &mut *self;
        Pin::new(&mut t.sink_base).start_send(&mut t.call, &msg, flags, t.req_ser)
    }

    #[inline]
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
        {
            let mut call = self.call.lock();
            call.check_alive()?;
        }
        let t = &mut *self;
        Pin::new(&mut t.sink_base).poll_flush(cx, &mut t.call)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
        let t = &mut *self;
        let mut call = t.call.lock();
        if t.close_f.is_none() {
            ready!(Pin::new(&mut t.sink_base).poll_ready(cx)?);

            let close_f = call.call.start_send_close_client()?;
            t.close_f = Some(close_f);
        }

        if Pin::new(t.close_f.as_mut().unwrap()).poll(cx)?.is_pending() {
            // The close future is still pending; if the call has already finished
            // with an error, `check_alive` lets us return it early.
            call.check_alive()?;
            return Poll::Pending;
        }
        Poll::Ready(Ok(()))
    }
}

/// A sink for a client streaming call.
///
/// To close the sink properly, you should call [`close`] before dropping it.
///
/// [`close`]: #method.close
pub type ClientCStreamSender<T> = StreamingCallSink<T>;
/// A sink for a duplex streaming call.
///
/// To close the sink properly, you should call [`close`] before dropping it.
///
/// [`close`]: #method.close
pub type ClientDuplexSender<T> = StreamingCallSink<T>;

struct ResponseStreamImpl<H, T> {
    call: H,
    msg_f: Option<BatchFuture>,
    read_done: bool,
    finished: bool,
    resp_de: DeserializeFn<T>,
}

impl<H: ShareCallHolder + Unpin, T> ResponseStreamImpl<H, T> {
    fn new(call: H, resp_de: DeserializeFn<T>) -> ResponseStreamImpl<H, T> {
        ResponseStreamImpl {
            call,
            msg_f: None,
            read_done: false,
            finished: false,
            resp_de,
        }
    }

    fn cancel(&mut self) {
        self.call.call(|c| c.call.cancel())
    }

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Result<T>>> {
        if !self.finished {
            // Drive the finish future so that `finished` is updated as soon as
            // the server sends its status.
            let t = &mut *self;
            let finished = &mut t.finished;
            let _ = t.call.call(|c| {
                let res = c.poll_finish(cx);
                *finished = c.finished;
                res
            })?;
        }

        let mut bytes = None;
        loop {
            if !self.read_done {
                if let Some(msg_f) = &mut self.msg_f {
                    bytes = ready!(Pin::new(msg_f).poll(cx)?);
                    if bytes.is_none() {
                        self.read_done = true;
                    }
                }
            }

            if self.read_done {
                if self.finished {
                    return Poll::Ready(None);
                }
                return Poll::Pending;
            }

            // At this point `msg_f` is either stale or not initialised yet, so
            // start receiving the next message.
            self.msg_f.take();
            let msg_f = self.call.call(|c| c.call.start_recv_message())?;
            self.msg_f = Some(msg_f);
            if let Some(data) = bytes {
                let msg = (self.resp_de)(data)?;
                return Poll::Ready(Some(Ok(msg)));
            }
        }
    }

    // Cancel the call if we still have some messages to read or have not
    // received the status code yet.
    fn on_drop(&mut self) {
        if !self.read_done || !self.finished {
            self.cancel();
        }
    }
}

/// A receiver for a server streaming call.
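///
/// A consumption sketch using `futures::TryStreamExt`; `list_features` stands
/// in for any generated server-streaming method (the name is illustrative):
///
/// ```ignore
/// use futures::TryStreamExt;
///
/// let mut stream = client.list_features(&req)?;
/// while let Some(feature) = stream.try_next().await? {
///     println!("got {:?}", feature);
/// }
/// ```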
#[must_use = "if unused the ClientSStreamReceiver may immediately cancel the RPC"]
pub struct ClientSStreamReceiver<Resp> {
    imp: ResponseStreamImpl<ShareCall, Resp>,
}

impl<Resp> ClientSStreamReceiver<Resp> {
    fn new(
        call: Call,
        finish_f: BatchFuture,
        de: DeserializeFn<Resp>,
    ) -> ClientSStreamReceiver<Resp> {
        let share_call = ShareCall::new(call, finish_f);
        ClientSStreamReceiver {
            imp: ResponseStreamImpl::new(share_call, de),
        }
    }

    pub fn cancel(&mut self) {
        self.imp.cancel()
    }
}

impl<Resp> Stream for ClientSStreamReceiver<Resp> {
    type Item = Result<Resp>;

    #[inline]
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.imp).poll(cx)
    }
}

/// A response receiver for a duplex call.
///
/// If the corresponding sink has been dropped or cancelled, polling this
/// receiver will yield a [`RpcFailure`] error with the [`Cancelled`] status.
///
/// [`RpcFailure`]: ./enum.Error.html#variant.RpcFailure
/// [`Cancelled`]: ./enum.RpcStatusCode.html#variant.Cancelled
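///
/// A usage sketch; `route_chat` stands in for any generated duplex-streaming
/// method returning a sender/receiver pair, and `note` is a placeholder
/// request message (both names are illustrative):
///
/// ```ignore
/// use futures::{SinkExt, TryStreamExt};
///
/// let (mut sink, mut receiver) = client.route_chat()?;
/// sink.send((note, WriteFlags::default())).await?;
/// sink.close().await?;
/// while let Some(reply) = receiver.try_next().await? {
///     println!("got {:?}", reply);
/// }
/// ```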
#[must_use = "if unused the ClientDuplexReceiver may immediately cancel the RPC"]
pub struct ClientDuplexReceiver<Resp> {
    imp: ResponseStreamImpl<Arc<Mutex<ShareCall>>, Resp>,
}

impl<Resp> ClientDuplexReceiver<Resp> {
    fn new(call: Arc<Mutex<ShareCall>>, de: DeserializeFn<Resp>) -> ClientDuplexReceiver<Resp> {
        ClientDuplexReceiver {
            imp: ResponseStreamImpl::new(call, de),
        }
    }

    /// Cancel the call.
    pub fn cancel(&mut self) {
        self.imp.cancel()
    }
}

impl<Resp> Drop for ClientDuplexReceiver<Resp> {
    /// The corresponding RPC will be canceled if the receiver is dropped
    /// before it finishes.
    fn drop(&mut self) {
        self.imp.on_drop()
    }
}

impl<Resp> Stream for ClientDuplexReceiver<Resp> {
    type Item = Result<Resp>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.imp).poll(cx)
    }
}

#[cfg(test)]
mod tests {
    #[test]
    fn test_change_flag() {
        let mut flag = 2 | 4;
        super::change_flag(&mut flag, 8, true);
        assert_eq!(flag, 2 | 4 | 8);
        super::change_flag(&mut flag, 4, false);
        assert_eq!(flag, 2 | 8);
    }
}