1 /*
2  *
3  * Copyright 2015 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
19 #include <forward_list>
20 #include <functional>
21 #include <list>
22 #include <memory>
23 #include <mutex>
24 #include <sstream>
25 #include <string>
26 #include <thread>
27 #include <utility>
28 #include <vector>
29 
30 #include <grpc/grpc.h>
31 #include <grpc/support/cpu.h>
32 #include <grpc/support/log.h>
33 #include <grpcpp/alarm.h>
34 #include <grpcpp/channel.h>
35 #include <grpcpp/client_context.h>
36 #include <grpcpp/generic/generic_stub.h>
37 
38 #include "src/core/lib/surface/completion_queue.h"
39 #include "src/proto/grpc/testing/benchmark_service.grpc.pb.h"
40 #include "test/cpp/qps/client.h"
41 #include "test/cpp/qps/usage_timer.h"
42 #include "test/cpp/util/create_test_channel.h"
43 
44 namespace grpc {
45 namespace testing {
46 
47 class ClientRpcContext {
48  public:
ClientRpcContext()49   ClientRpcContext() {}
~ClientRpcContext()50   virtual ~ClientRpcContext() {}
51   // next state, return false if done. Collect stats when appropriate
52   virtual bool RunNextState(bool, HistogramEntry* entry) = 0;
53   virtual void StartNewClone(CompletionQueue* cq) = 0;
tag(ClientRpcContext * c)54   static void* tag(ClientRpcContext* c) { return static_cast<void*>(c); }
detag(void * t)55   static ClientRpcContext* detag(void* t) {
56     return static_cast<ClientRpcContext*>(t);
57   }
58 
59   virtual void Start(CompletionQueue* cq, const ClientConfig& config) = 0;
60   virtual void TryCancel() = 0;
61 };
62 
63 template <class RequestType, class ResponseType>
64 class ClientRpcContextUnaryImpl : public ClientRpcContext {
65  public:
ClientRpcContextUnaryImpl(BenchmarkService::Stub * stub,const RequestType & req,std::function<gpr_timespec ()> next_issue,std::function<std::unique_ptr<grpc::ClientAsyncResponseReader<ResponseType>> (BenchmarkService::Stub *,grpc::ClientContext *,const RequestType &,CompletionQueue *)> prepare_req,std::function<void (grpc::Status,ResponseType *,HistogramEntry *)> on_done)66   ClientRpcContextUnaryImpl(
67       BenchmarkService::Stub* stub, const RequestType& req,
68       std::function<gpr_timespec()> next_issue,
69       std::function<
70           std::unique_ptr<grpc::ClientAsyncResponseReader<ResponseType>>(
71               BenchmarkService::Stub*, grpc::ClientContext*, const RequestType&,
72               CompletionQueue*)>
73           prepare_req,
74       std::function<void(grpc::Status, ResponseType*, HistogramEntry*)> on_done)
75       : context_(),
76         stub_(stub),
77         cq_(nullptr),
78         req_(req),
79         response_(),
80         next_state_(State::READY),
81         callback_(on_done),
82         next_issue_(std::move(next_issue)),
83         prepare_req_(prepare_req) {}
~ClientRpcContextUnaryImpl()84   ~ClientRpcContextUnaryImpl() override {}
Start(CompletionQueue * cq,const ClientConfig & config)85   void Start(CompletionQueue* cq, const ClientConfig& config) override {
86     GPR_ASSERT(!config.use_coalesce_api());  // not supported.
87     StartInternal(cq);
88   }
RunNextState(bool ok,HistogramEntry * entry)89   bool RunNextState(bool ok, HistogramEntry* entry) override {
90     switch (next_state_) {
91       case State::READY:
92         start_ = UsageTimer::Now();
93         response_reader_ = prepare_req_(stub_, &context_, req_, cq_);
94         response_reader_->StartCall();
95         next_state_ = State::RESP_DONE;
96         response_reader_->Finish(&response_, &status_,
97                                  ClientRpcContext::tag(this));
98         return true;
99       case State::RESP_DONE:
100         if (status_.ok()) {
101           entry->set_value((UsageTimer::Now() - start_) * 1e9);
102         }
103         callback_(status_, &response_, entry);
104         next_state_ = State::INVALID;
105         return false;
106       default:
107         GPR_ASSERT(false);
108         return false;
109     }
110   }
StartNewClone(CompletionQueue * cq)111   void StartNewClone(CompletionQueue* cq) override {
112     auto* clone = new ClientRpcContextUnaryImpl(stub_, req_, next_issue_,
113                                                 prepare_req_, callback_);
114     clone->StartInternal(cq);
115   }
TryCancel()116   void TryCancel() override { context_.TryCancel(); }
117 
118  private:
119   grpc::ClientContext context_;
120   BenchmarkService::Stub* stub_;
121   CompletionQueue* cq_;
122   std::unique_ptr<Alarm> alarm_;
123   const RequestType& req_;
124   ResponseType response_;
125   enum State { INVALID, READY, RESP_DONE };
126   State next_state_;
127   std::function<void(grpc::Status, ResponseType*, HistogramEntry*)> callback_;
128   std::function<gpr_timespec()> next_issue_;
129   std::function<std::unique_ptr<grpc::ClientAsyncResponseReader<ResponseType>>(
130       BenchmarkService::Stub*, grpc::ClientContext*, const RequestType&,
131       CompletionQueue*)>
132       prepare_req_;
133   grpc::Status status_;
134   double start_;
135   std::unique_ptr<grpc::ClientAsyncResponseReader<ResponseType>>
136       response_reader_;
137 
StartInternal(CompletionQueue * cq)138   void StartInternal(CompletionQueue* cq) {
139     cq_ = cq;
140     if (!next_issue_) {  // ready to issue
141       RunNextState(true, nullptr);
142     } else {  // wait for the issue time
143       alarm_.reset(new Alarm);
144       alarm_->Set(cq_, next_issue_(), ClientRpcContext::tag(this));
145     }
146   }
147 };
148 
149 template <class StubType, class RequestType>
150 class AsyncClient : public ClientImpl<StubType, RequestType> {
151   // Specify which protected members we are using since there is no
152   // member name resolution until the template types are fully resolved
153  public:
154   using Client::NextIssuer;
155   using Client::SetupLoadTest;
156   using Client::closed_loop_;
157   using ClientImpl<StubType, RequestType>::cores_;
158   using ClientImpl<StubType, RequestType>::channels_;
159   using ClientImpl<StubType, RequestType>::request_;
AsyncClient(const ClientConfig & config,std::function<ClientRpcContext * (StubType *,std::function<gpr_timespec ()> next_issue,const RequestType &)> setup_ctx,std::function<std::unique_ptr<StubType> (std::shared_ptr<Channel>)> create_stub)160   AsyncClient(const ClientConfig& config,
161               std::function<ClientRpcContext*(
162                   StubType*, std::function<gpr_timespec()> next_issue,
163                   const RequestType&)>
164                   setup_ctx,
165               std::function<std::unique_ptr<StubType>(std::shared_ptr<Channel>)>
166                   create_stub)
167       : ClientImpl<StubType, RequestType>(config, create_stub),
168         num_async_threads_(NumThreads(config)) {
169     SetupLoadTest(config, num_async_threads_);
170 
171     int tpc = std::max(1, config.threads_per_cq());      // 1 if unspecified
172     int num_cqs = (num_async_threads_ + tpc - 1) / tpc;  // ceiling operator
173     for (int i = 0; i < num_cqs; i++) {
174       cli_cqs_.emplace_back(new CompletionQueue);
175     }
176 
177     for (int i = 0; i < num_async_threads_; i++) {
178       cq_.emplace_back(i % cli_cqs_.size());
179       next_issuers_.emplace_back(NextIssuer(i));
180       shutdown_state_.emplace_back(new PerThreadShutdownState());
181     }
182 
183     int t = 0;
184     for (int ch = 0; ch < config.client_channels(); ch++) {
185       for (int i = 0; i < config.outstanding_rpcs_per_channel(); i++) {
186         auto* cq = cli_cqs_[t].get();
187         auto ctx =
188             setup_ctx(channels_[ch].get_stub(), next_issuers_[t], request_);
189         ctx->Start(cq, config);
190       }
191       t = (t + 1) % cli_cqs_.size();
192     }
193   }
~AsyncClient()194   virtual ~AsyncClient() {
195     for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) {
196       void* got_tag;
197       bool ok;
198       while ((*cq)->Next(&got_tag, &ok)) {
199         delete ClientRpcContext::detag(got_tag);
200       }
201     }
202   }
203 
GetPollCount()204   int GetPollCount() override {
205     int count = 0;
206     for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) {
207       count += grpc_get_cq_poll_num((*cq)->cq());
208     }
209     return count;
210   }
211 
212  protected:
213   const int num_async_threads_;
214 
215  private:
216   struct PerThreadShutdownState {
217     mutable std::mutex mutex;
218     bool shutdown;
PerThreadShutdownStategrpc::testing::AsyncClient::PerThreadShutdownState219     PerThreadShutdownState() : shutdown(false) {}
220   };
221 
NumThreads(const ClientConfig & config)222   int NumThreads(const ClientConfig& config) {
223     int num_threads = config.async_client_threads();
224     if (num_threads <= 0) {  // Use dynamic sizing
225       num_threads = cores_;
226       gpr_log(GPR_INFO, "Sizing async client to %d threads", num_threads);
227     }
228     return num_threads;
229   }
DestroyMultithreading()230   void DestroyMultithreading() override final {
231     for (auto ss = shutdown_state_.begin(); ss != shutdown_state_.end(); ++ss) {
232       std::lock_guard<std::mutex> lock((*ss)->mutex);
233       (*ss)->shutdown = true;
234     }
235     for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) {
236       (*cq)->Shutdown();
237     }
238     this->EndThreads();  // this needed for resolution
239   }
240 
ProcessTag(size_t thread_idx,void * tag)241   ClientRpcContext* ProcessTag(size_t thread_idx, void* tag) {
242     ClientRpcContext* ctx = ClientRpcContext::detag(tag);
243     if (shutdown_state_[thread_idx]->shutdown) {
244       ctx->TryCancel();
245       delete ctx;
246       bool ok;
247       while (cli_cqs_[cq_[thread_idx]]->Next(&tag, &ok)) {
248         ctx = ClientRpcContext::detag(tag);
249         ctx->TryCancel();
250         delete ctx;
251       }
252       return nullptr;
253     }
254     return ctx;
255   }
256 
ThreadFunc(size_t thread_idx,Client::Thread * t)257   void ThreadFunc(size_t thread_idx, Client::Thread* t) override final {
258     void* got_tag;
259     bool ok;
260 
261     HistogramEntry entry;
262     HistogramEntry* entry_ptr = &entry;
263     if (!cli_cqs_[cq_[thread_idx]]->Next(&got_tag, &ok)) {
264       return;
265     }
266     std::mutex* shutdown_mu = &shutdown_state_[thread_idx]->mutex;
267     shutdown_mu->lock();
268     ClientRpcContext* ctx = ProcessTag(thread_idx, got_tag);
269     if (ctx == nullptr) {
270       shutdown_mu->unlock();
271       return;
272     }
273     while (cli_cqs_[cq_[thread_idx]]->DoThenAsyncNext(
274         [&, ctx, ok, entry_ptr, shutdown_mu]() {
275           if (!ctx->RunNextState(ok, entry_ptr)) {
276             // The RPC and callback are done, so clone the ctx
277             // and kickstart the new one
278             ctx->StartNewClone(cli_cqs_[cq_[thread_idx]].get());
279             delete ctx;
280           }
281           shutdown_mu->unlock();
282         },
283         &got_tag, &ok, gpr_inf_future(GPR_CLOCK_REALTIME))) {
284       t->UpdateHistogram(entry_ptr);
285       entry = HistogramEntry();
286       shutdown_mu->lock();
287       ctx = ProcessTag(thread_idx, got_tag);
288       if (ctx == nullptr) {
289         shutdown_mu->unlock();
290         return;
291       }
292     }
293   }
294 
295   std::vector<std::unique_ptr<CompletionQueue>> cli_cqs_;
296   std::vector<int> cq_;
297   std::vector<std::function<gpr_timespec()>> next_issuers_;
298   std::vector<std::unique_ptr<PerThreadShutdownState>> shutdown_state_;
299 };
300 
BenchmarkStubCreator(const std::shared_ptr<Channel> & ch)301 static std::unique_ptr<BenchmarkService::Stub> BenchmarkStubCreator(
302     const std::shared_ptr<Channel>& ch) {
303   return BenchmarkService::NewStub(ch);
304 }
305 
306 class AsyncUnaryClient final
307     : public AsyncClient<BenchmarkService::Stub, SimpleRequest> {
308  public:
AsyncUnaryClient(const ClientConfig & config)309   explicit AsyncUnaryClient(const ClientConfig& config)
310       : AsyncClient<BenchmarkService::Stub, SimpleRequest>(
311             config, SetupCtx, BenchmarkStubCreator) {
312     StartThreads(num_async_threads_);
313   }
~AsyncUnaryClient()314   ~AsyncUnaryClient() override {}
315 
316  private:
CheckDone(const grpc::Status & s,SimpleResponse * response,HistogramEntry * entry)317   static void CheckDone(const grpc::Status& s, SimpleResponse* response,
318                         HistogramEntry* entry) {
319     entry->set_status(s.error_code());
320   }
321   static std::unique_ptr<grpc::ClientAsyncResponseReader<SimpleResponse>>
PrepareReq(BenchmarkService::Stub * stub,grpc::ClientContext * ctx,const SimpleRequest & request,CompletionQueue * cq)322   PrepareReq(BenchmarkService::Stub* stub, grpc::ClientContext* ctx,
323              const SimpleRequest& request, CompletionQueue* cq) {
324     return stub->PrepareAsyncUnaryCall(ctx, request, cq);
325   };
SetupCtx(BenchmarkService::Stub * stub,std::function<gpr_timespec ()> next_issue,const SimpleRequest & req)326   static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub,
327                                     std::function<gpr_timespec()> next_issue,
328                                     const SimpleRequest& req) {
329     return new ClientRpcContextUnaryImpl<SimpleRequest, SimpleResponse>(
330         stub, req, std::move(next_issue), AsyncUnaryClient::PrepareReq,
331         AsyncUnaryClient::CheckDone);
332   }
333 };
334 
335 template <class RequestType, class ResponseType>
336 class ClientRpcContextStreamingPingPongImpl : public ClientRpcContext {
337  public:
ClientRpcContextStreamingPingPongImpl(BenchmarkService::Stub * stub,const RequestType & req,std::function<gpr_timespec ()> next_issue,std::function<std::unique_ptr<grpc::ClientAsyncReaderWriter<RequestType,ResponseType>> (BenchmarkService::Stub *,grpc::ClientContext *,CompletionQueue *)> prepare_req,std::function<void (grpc::Status,ResponseType *)> on_done)338   ClientRpcContextStreamingPingPongImpl(
339       BenchmarkService::Stub* stub, const RequestType& req,
340       std::function<gpr_timespec()> next_issue,
341       std::function<std::unique_ptr<
342           grpc::ClientAsyncReaderWriter<RequestType, ResponseType>>(
343           BenchmarkService::Stub*, grpc::ClientContext*, CompletionQueue*)>
344           prepare_req,
345       std::function<void(grpc::Status, ResponseType*)> on_done)
346       : context_(),
347         stub_(stub),
348         cq_(nullptr),
349         req_(req),
350         response_(),
351         next_state_(State::INVALID),
352         callback_(on_done),
353         next_issue_(std::move(next_issue)),
354         prepare_req_(prepare_req),
355         coalesce_(false) {}
~ClientRpcContextStreamingPingPongImpl()356   ~ClientRpcContextStreamingPingPongImpl() override {}
Start(CompletionQueue * cq,const ClientConfig & config)357   void Start(CompletionQueue* cq, const ClientConfig& config) override {
358     StartInternal(cq, config.messages_per_stream(), config.use_coalesce_api());
359   }
RunNextState(bool ok,HistogramEntry * entry)360   bool RunNextState(bool ok, HistogramEntry* entry) override {
361     while (true) {
362       switch (next_state_) {
363         case State::STREAM_IDLE:
364           if (!next_issue_) {  // ready to issue
365             next_state_ = State::READY_TO_WRITE;
366           } else {
367             next_state_ = State::WAIT;
368           }
369           break;  // loop around, don't return
370         case State::WAIT:
371           next_state_ = State::READY_TO_WRITE;
372           alarm_.reset(new Alarm);
373           alarm_->Set(cq_, next_issue_(), ClientRpcContext::tag(this));
374           return true;
375         case State::READY_TO_WRITE:
376           if (!ok) {
377             return false;
378           }
379           start_ = UsageTimer::Now();
380           next_state_ = State::WRITE_DONE;
381           if (coalesce_ && messages_issued_ == messages_per_stream_ - 1) {
382             stream_->WriteLast(req_, WriteOptions(),
383                                ClientRpcContext::tag(this));
384           } else {
385             stream_->Write(req_, ClientRpcContext::tag(this));
386           }
387           return true;
388         case State::WRITE_DONE:
389           if (!ok) {
390             return false;
391           }
392           next_state_ = State::READ_DONE;
393           stream_->Read(&response_, ClientRpcContext::tag(this));
394           return true;
395           break;
396         case State::READ_DONE:
397           entry->set_value((UsageTimer::Now() - start_) * 1e9);
398           callback_(status_, &response_);
399           if ((messages_per_stream_ != 0) &&
400               (++messages_issued_ >= messages_per_stream_)) {
401             next_state_ = State::WRITES_DONE_DONE;
402             if (coalesce_) {
403               // WritesDone should have been called on the last Write.
404               // loop around to call Finish.
405               break;
406             }
407             stream_->WritesDone(ClientRpcContext::tag(this));
408             return true;
409           }
410           next_state_ = State::STREAM_IDLE;
411           break;  // loop around
412         case State::WRITES_DONE_DONE:
413           next_state_ = State::FINISH_DONE;
414           stream_->Finish(&status_, ClientRpcContext::tag(this));
415           return true;
416         case State::FINISH_DONE:
417           next_state_ = State::INVALID;
418           return false;
419           break;
420         default:
421           GPR_ASSERT(false);
422           return false;
423       }
424     }
425   }
StartNewClone(CompletionQueue * cq)426   void StartNewClone(CompletionQueue* cq) override {
427     auto* clone = new ClientRpcContextStreamingPingPongImpl(
428         stub_, req_, next_issue_, prepare_req_, callback_);
429     clone->StartInternal(cq, messages_per_stream_, coalesce_);
430   }
TryCancel()431   void TryCancel() override { context_.TryCancel(); }
432 
433  private:
434   grpc::ClientContext context_;
435   BenchmarkService::Stub* stub_;
436   CompletionQueue* cq_;
437   std::unique_ptr<Alarm> alarm_;
438   const RequestType& req_;
439   ResponseType response_;
440   enum State {
441     INVALID,
442     STREAM_IDLE,
443     WAIT,
444     READY_TO_WRITE,
445     WRITE_DONE,
446     READ_DONE,
447     WRITES_DONE_DONE,
448     FINISH_DONE
449   };
450   State next_state_;
451   std::function<void(grpc::Status, ResponseType*)> callback_;
452   std::function<gpr_timespec()> next_issue_;
453   std::function<
454       std::unique_ptr<grpc::ClientAsyncReaderWriter<RequestType, ResponseType>>(
455           BenchmarkService::Stub*, grpc::ClientContext*, CompletionQueue*)>
456       prepare_req_;
457   grpc::Status status_;
458   double start_;
459   std::unique_ptr<grpc::ClientAsyncReaderWriter<RequestType, ResponseType>>
460       stream_;
461 
462   // Allow a limit on number of messages in a stream
463   int messages_per_stream_;
464   int messages_issued_;
465   // Whether to use coalescing API.
466   bool coalesce_;
467 
StartInternal(CompletionQueue * cq,int messages_per_stream,bool coalesce)468   void StartInternal(CompletionQueue* cq, int messages_per_stream,
469                      bool coalesce) {
470     cq_ = cq;
471     messages_per_stream_ = messages_per_stream;
472     messages_issued_ = 0;
473     coalesce_ = coalesce;
474     if (coalesce_) {
475       GPR_ASSERT(messages_per_stream_ != 0);
476       context_.set_initial_metadata_corked(true);
477     }
478     stream_ = prepare_req_(stub_, &context_, cq);
479     next_state_ = State::STREAM_IDLE;
480     stream_->StartCall(ClientRpcContext::tag(this));
481     if (coalesce_) {
482       // When the intial metadata is corked, the tag will not come back and we
483       // need to manually drive the state machine.
484       RunNextState(true, nullptr);
485     }
486   }
487 };
488 
489 class AsyncStreamingPingPongClient final
490     : public AsyncClient<BenchmarkService::Stub, SimpleRequest> {
491  public:
AsyncStreamingPingPongClient(const ClientConfig & config)492   explicit AsyncStreamingPingPongClient(const ClientConfig& config)
493       : AsyncClient<BenchmarkService::Stub, SimpleRequest>(
494             config, SetupCtx, BenchmarkStubCreator) {
495     StartThreads(num_async_threads_);
496   }
497 
~AsyncStreamingPingPongClient()498   ~AsyncStreamingPingPongClient() override {}
499 
500  private:
CheckDone(const grpc::Status & s,SimpleResponse * response)501   static void CheckDone(const grpc::Status& s, SimpleResponse* response) {}
502   static std::unique_ptr<
503       grpc::ClientAsyncReaderWriter<SimpleRequest, SimpleResponse>>
PrepareReq(BenchmarkService::Stub * stub,grpc::ClientContext * ctx,CompletionQueue * cq)504   PrepareReq(BenchmarkService::Stub* stub, grpc::ClientContext* ctx,
505              CompletionQueue* cq) {
506     auto stream = stub->PrepareAsyncStreamingCall(ctx, cq);
507     return stream;
508   };
SetupCtx(BenchmarkService::Stub * stub,std::function<gpr_timespec ()> next_issue,const SimpleRequest & req)509   static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub,
510                                     std::function<gpr_timespec()> next_issue,
511                                     const SimpleRequest& req) {
512     return new ClientRpcContextStreamingPingPongImpl<SimpleRequest,
513                                                      SimpleResponse>(
514         stub, req, std::move(next_issue),
515         AsyncStreamingPingPongClient::PrepareReq,
516         AsyncStreamingPingPongClient::CheckDone);
517   }
518 };
519 
520 template <class RequestType, class ResponseType>
521 class ClientRpcContextStreamingFromClientImpl : public ClientRpcContext {
522  public:
ClientRpcContextStreamingFromClientImpl(BenchmarkService::Stub * stub,const RequestType & req,std::function<gpr_timespec ()> next_issue,std::function<std::unique_ptr<grpc::ClientAsyncWriter<RequestType>> (BenchmarkService::Stub *,grpc::ClientContext *,ResponseType *,CompletionQueue *)> prepare_req,std::function<void (grpc::Status,ResponseType *)> on_done)523   ClientRpcContextStreamingFromClientImpl(
524       BenchmarkService::Stub* stub, const RequestType& req,
525       std::function<gpr_timespec()> next_issue,
526       std::function<std::unique_ptr<grpc::ClientAsyncWriter<RequestType>>(
527           BenchmarkService::Stub*, grpc::ClientContext*, ResponseType*,
528           CompletionQueue*)>
529           prepare_req,
530       std::function<void(grpc::Status, ResponseType*)> on_done)
531       : context_(),
532         stub_(stub),
533         cq_(nullptr),
534         req_(req),
535         response_(),
536         next_state_(State::INVALID),
537         callback_(on_done),
538         next_issue_(std::move(next_issue)),
539         prepare_req_(prepare_req) {}
~ClientRpcContextStreamingFromClientImpl()540   ~ClientRpcContextStreamingFromClientImpl() override {}
Start(CompletionQueue * cq,const ClientConfig & config)541   void Start(CompletionQueue* cq, const ClientConfig& config) override {
542     GPR_ASSERT(!config.use_coalesce_api());  // not supported yet.
543     StartInternal(cq);
544   }
RunNextState(bool ok,HistogramEntry * entry)545   bool RunNextState(bool ok, HistogramEntry* entry) override {
546     while (true) {
547       switch (next_state_) {
548         case State::STREAM_IDLE:
549           if (!next_issue_) {  // ready to issue
550             next_state_ = State::READY_TO_WRITE;
551           } else {
552             next_state_ = State::WAIT;
553           }
554           break;  // loop around, don't return
555         case State::WAIT:
556           alarm_.reset(new Alarm);
557           alarm_->Set(cq_, next_issue_(), ClientRpcContext::tag(this));
558           next_state_ = State::READY_TO_WRITE;
559           return true;
560         case State::READY_TO_WRITE:
561           if (!ok) {
562             return false;
563           }
564           start_ = UsageTimer::Now();
565           next_state_ = State::WRITE_DONE;
566           stream_->Write(req_, ClientRpcContext::tag(this));
567           return true;
568         case State::WRITE_DONE:
569           if (!ok) {
570             return false;
571           }
572           entry->set_value((UsageTimer::Now() - start_) * 1e9);
573           next_state_ = State::STREAM_IDLE;
574           break;  // loop around
575         default:
576           GPR_ASSERT(false);
577           return false;
578       }
579     }
580   }
StartNewClone(CompletionQueue * cq)581   void StartNewClone(CompletionQueue* cq) override {
582     auto* clone = new ClientRpcContextStreamingFromClientImpl(
583         stub_, req_, next_issue_, prepare_req_, callback_);
584     clone->StartInternal(cq);
585   }
TryCancel()586   void TryCancel() override { context_.TryCancel(); }
587 
588  private:
589   grpc::ClientContext context_;
590   BenchmarkService::Stub* stub_;
591   CompletionQueue* cq_;
592   std::unique_ptr<Alarm> alarm_;
593   const RequestType& req_;
594   ResponseType response_;
595   enum State {
596     INVALID,
597     STREAM_IDLE,
598     WAIT,
599     READY_TO_WRITE,
600     WRITE_DONE,
601   };
602   State next_state_;
603   std::function<void(grpc::Status, ResponseType*)> callback_;
604   std::function<gpr_timespec()> next_issue_;
605   std::function<std::unique_ptr<grpc::ClientAsyncWriter<RequestType>>(
606       BenchmarkService::Stub*, grpc::ClientContext*, ResponseType*,
607       CompletionQueue*)>
608       prepare_req_;
609   grpc::Status status_;
610   double start_;
611   std::unique_ptr<grpc::ClientAsyncWriter<RequestType>> stream_;
612 
StartInternal(CompletionQueue * cq)613   void StartInternal(CompletionQueue* cq) {
614     cq_ = cq;
615     stream_ = prepare_req_(stub_, &context_, &response_, cq);
616     next_state_ = State::STREAM_IDLE;
617     stream_->StartCall(ClientRpcContext::tag(this));
618   }
619 };
620 
621 class AsyncStreamingFromClientClient final
622     : public AsyncClient<BenchmarkService::Stub, SimpleRequest> {
623  public:
AsyncStreamingFromClientClient(const ClientConfig & config)624   explicit AsyncStreamingFromClientClient(const ClientConfig& config)
625       : AsyncClient<BenchmarkService::Stub, SimpleRequest>(
626             config, SetupCtx, BenchmarkStubCreator) {
627     StartThreads(num_async_threads_);
628   }
629 
~AsyncStreamingFromClientClient()630   ~AsyncStreamingFromClientClient() override {}
631 
632  private:
CheckDone(const grpc::Status & s,SimpleResponse * response)633   static void CheckDone(const grpc::Status& s, SimpleResponse* response) {}
PrepareReq(BenchmarkService::Stub * stub,grpc::ClientContext * ctx,SimpleResponse * resp,CompletionQueue * cq)634   static std::unique_ptr<grpc::ClientAsyncWriter<SimpleRequest>> PrepareReq(
635       BenchmarkService::Stub* stub, grpc::ClientContext* ctx,
636       SimpleResponse* resp, CompletionQueue* cq) {
637     auto stream = stub->PrepareAsyncStreamingFromClient(ctx, resp, cq);
638     return stream;
639   };
SetupCtx(BenchmarkService::Stub * stub,std::function<gpr_timespec ()> next_issue,const SimpleRequest & req)640   static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub,
641                                     std::function<gpr_timespec()> next_issue,
642                                     const SimpleRequest& req) {
643     return new ClientRpcContextStreamingFromClientImpl<SimpleRequest,
644                                                        SimpleResponse>(
645         stub, req, std::move(next_issue),
646         AsyncStreamingFromClientClient::PrepareReq,
647         AsyncStreamingFromClientClient::CheckDone);
648   }
649 };
650 
651 template <class RequestType, class ResponseType>
652 class ClientRpcContextStreamingFromServerImpl : public ClientRpcContext {
653  public:
ClientRpcContextStreamingFromServerImpl(BenchmarkService::Stub * stub,const RequestType & req,std::function<gpr_timespec ()> next_issue,std::function<std::unique_ptr<grpc::ClientAsyncReader<ResponseType>> (BenchmarkService::Stub *,grpc::ClientContext *,const RequestType &,CompletionQueue *)> prepare_req,std::function<void (grpc::Status,ResponseType *)> on_done)654   ClientRpcContextStreamingFromServerImpl(
655       BenchmarkService::Stub* stub, const RequestType& req,
656       std::function<gpr_timespec()> next_issue,
657       std::function<std::unique_ptr<grpc::ClientAsyncReader<ResponseType>>(
658           BenchmarkService::Stub*, grpc::ClientContext*, const RequestType&,
659           CompletionQueue*)>
660           prepare_req,
661       std::function<void(grpc::Status, ResponseType*)> on_done)
662       : context_(),
663         stub_(stub),
664         cq_(nullptr),
665         req_(req),
666         response_(),
667         next_state_(State::INVALID),
668         callback_(on_done),
669         next_issue_(std::move(next_issue)),
670         prepare_req_(prepare_req) {}
~ClientRpcContextStreamingFromServerImpl()671   ~ClientRpcContextStreamingFromServerImpl() override {}
Start(CompletionQueue * cq,const ClientConfig & config)672   void Start(CompletionQueue* cq, const ClientConfig& config) override {
673     GPR_ASSERT(!config.use_coalesce_api());  // not supported
674     StartInternal(cq);
675   }
RunNextState(bool ok,HistogramEntry * entry)676   bool RunNextState(bool ok, HistogramEntry* entry) override {
677     while (true) {
678       switch (next_state_) {
679         case State::STREAM_IDLE:
680           if (!ok) {
681             return false;
682           }
683           start_ = UsageTimer::Now();
684           next_state_ = State::READ_DONE;
685           stream_->Read(&response_, ClientRpcContext::tag(this));
686           return true;
687         case State::READ_DONE:
688           if (!ok) {
689             return false;
690           }
691           entry->set_value((UsageTimer::Now() - start_) * 1e9);
692           callback_(status_, &response_);
693           next_state_ = State::STREAM_IDLE;
694           break;  // loop around
695         default:
696           GPR_ASSERT(false);
697           return false;
698       }
699     }
700   }
StartNewClone(CompletionQueue * cq)701   void StartNewClone(CompletionQueue* cq) override {
702     auto* clone = new ClientRpcContextStreamingFromServerImpl(
703         stub_, req_, next_issue_, prepare_req_, callback_);
704     clone->StartInternal(cq);
705   }
TryCancel()706   void TryCancel() override { context_.TryCancel(); }
707 
708  private:
709   grpc::ClientContext context_;
710   BenchmarkService::Stub* stub_;
711   CompletionQueue* cq_;
712   std::unique_ptr<Alarm> alarm_;
713   const RequestType& req_;
714   ResponseType response_;
715   enum State { INVALID, STREAM_IDLE, READ_DONE };
716   State next_state_;
717   std::function<void(grpc::Status, ResponseType*)> callback_;
718   std::function<gpr_timespec()> next_issue_;
719   std::function<std::unique_ptr<grpc::ClientAsyncReader<ResponseType>>(
720       BenchmarkService::Stub*, grpc::ClientContext*, const RequestType&,
721       CompletionQueue*)>
722       prepare_req_;
723   grpc::Status status_;
724   double start_;
725   std::unique_ptr<grpc::ClientAsyncReader<ResponseType>> stream_;
726 
StartInternal(CompletionQueue * cq)727   void StartInternal(CompletionQueue* cq) {
728     // TODO(vjpai): Add support to rate-pace this
729     cq_ = cq;
730     stream_ = prepare_req_(stub_, &context_, req_, cq);
731     next_state_ = State::STREAM_IDLE;
732     stream_->StartCall(ClientRpcContext::tag(this));
733   }
734 };
735 
736 class AsyncStreamingFromServerClient final
737     : public AsyncClient<BenchmarkService::Stub, SimpleRequest> {
738  public:
AsyncStreamingFromServerClient(const ClientConfig & config)739   explicit AsyncStreamingFromServerClient(const ClientConfig& config)
740       : AsyncClient<BenchmarkService::Stub, SimpleRequest>(
741             config, SetupCtx, BenchmarkStubCreator) {
742     StartThreads(num_async_threads_);
743   }
744 
~AsyncStreamingFromServerClient()745   ~AsyncStreamingFromServerClient() override {}
746 
747  private:
CheckDone(const grpc::Status & s,SimpleResponse * response)748   static void CheckDone(const grpc::Status& s, SimpleResponse* response) {}
PrepareReq(BenchmarkService::Stub * stub,grpc::ClientContext * ctx,const SimpleRequest & req,CompletionQueue * cq)749   static std::unique_ptr<grpc::ClientAsyncReader<SimpleResponse>> PrepareReq(
750       BenchmarkService::Stub* stub, grpc::ClientContext* ctx,
751       const SimpleRequest& req, CompletionQueue* cq) {
752     auto stream = stub->PrepareAsyncStreamingFromServer(ctx, req, cq);
753     return stream;
754   };
SetupCtx(BenchmarkService::Stub * stub,std::function<gpr_timespec ()> next_issue,const SimpleRequest & req)755   static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub,
756                                     std::function<gpr_timespec()> next_issue,
757                                     const SimpleRequest& req) {
758     return new ClientRpcContextStreamingFromServerImpl<SimpleRequest,
759                                                        SimpleResponse>(
760         stub, req, std::move(next_issue),
761         AsyncStreamingFromServerClient::PrepareReq,
762         AsyncStreamingFromServerClient::CheckDone);
763   }
764 };
765 
766 class ClientRpcContextGenericStreamingImpl : public ClientRpcContext {
767  public:
ClientRpcContextGenericStreamingImpl(grpc::GenericStub * stub,const ByteBuffer & req,std::function<gpr_timespec ()> next_issue,std::function<std::unique_ptr<grpc::GenericClientAsyncReaderWriter> (grpc::GenericStub *,grpc::ClientContext *,const grpc::string & method_name,CompletionQueue *)> prepare_req,std::function<void (grpc::Status,ByteBuffer *)> on_done)768   ClientRpcContextGenericStreamingImpl(
769       grpc::GenericStub* stub, const ByteBuffer& req,
770       std::function<gpr_timespec()> next_issue,
771       std::function<std::unique_ptr<grpc::GenericClientAsyncReaderWriter>(
772           grpc::GenericStub*, grpc::ClientContext*,
773           const grpc::string& method_name, CompletionQueue*)>
774           prepare_req,
775       std::function<void(grpc::Status, ByteBuffer*)> on_done)
776       : context_(),
777         stub_(stub),
778         cq_(nullptr),
779         req_(req),
780         response_(),
781         next_state_(State::INVALID),
782         callback_(std::move(on_done)),
783         next_issue_(std::move(next_issue)),
784         prepare_req_(std::move(prepare_req)) {}
~ClientRpcContextGenericStreamingImpl()785   ~ClientRpcContextGenericStreamingImpl() override {}
Start(CompletionQueue * cq,const ClientConfig & config)786   void Start(CompletionQueue* cq, const ClientConfig& config) override {
787     GPR_ASSERT(!config.use_coalesce_api());  // not supported yet.
788     StartInternal(cq, config.messages_per_stream());
789   }
RunNextState(bool ok,HistogramEntry * entry)790   bool RunNextState(bool ok, HistogramEntry* entry) override {
791     while (true) {
792       switch (next_state_) {
793         case State::STREAM_IDLE:
794           if (!next_issue_) {  // ready to issue
795             next_state_ = State::READY_TO_WRITE;
796           } else {
797             next_state_ = State::WAIT;
798           }
799           break;  // loop around, don't return
800         case State::WAIT:
801           next_state_ = State::READY_TO_WRITE;
802           alarm_.reset(new Alarm);
803           alarm_->Set(cq_, next_issue_(), ClientRpcContext::tag(this));
804           return true;
805         case State::READY_TO_WRITE:
806           if (!ok) {
807             return false;
808           }
809           start_ = UsageTimer::Now();
810           next_state_ = State::WRITE_DONE;
811           stream_->Write(req_, ClientRpcContext::tag(this));
812           return true;
813         case State::WRITE_DONE:
814           if (!ok) {
815             return false;
816           }
817           next_state_ = State::READ_DONE;
818           stream_->Read(&response_, ClientRpcContext::tag(this));
819           return true;
820           break;
821         case State::READ_DONE:
822           entry->set_value((UsageTimer::Now() - start_) * 1e9);
823           callback_(status_, &response_);
824           if ((messages_per_stream_ != 0) &&
825               (++messages_issued_ >= messages_per_stream_)) {
826             next_state_ = State::WRITES_DONE_DONE;
827             stream_->WritesDone(ClientRpcContext::tag(this));
828             return true;
829           }
830           next_state_ = State::STREAM_IDLE;
831           break;  // loop around
832         case State::WRITES_DONE_DONE:
833           next_state_ = State::FINISH_DONE;
834           stream_->Finish(&status_, ClientRpcContext::tag(this));
835           return true;
836         case State::FINISH_DONE:
837           next_state_ = State::INVALID;
838           return false;
839           break;
840         default:
841           GPR_ASSERT(false);
842           return false;
843       }
844     }
845   }
StartNewClone(CompletionQueue * cq)846   void StartNewClone(CompletionQueue* cq) override {
847     auto* clone = new ClientRpcContextGenericStreamingImpl(
848         stub_, req_, next_issue_, prepare_req_, callback_);
849     clone->StartInternal(cq, messages_per_stream_);
850   }
TryCancel()851   void TryCancel() override { context_.TryCancel(); }
852 
853  private:
854   grpc::ClientContext context_;
855   grpc::GenericStub* stub_;
856   CompletionQueue* cq_;
857   std::unique_ptr<Alarm> alarm_;
858   ByteBuffer req_;
859   ByteBuffer response_;
860   enum State {
861     INVALID,
862     STREAM_IDLE,
863     WAIT,
864     READY_TO_WRITE,
865     WRITE_DONE,
866     READ_DONE,
867     WRITES_DONE_DONE,
868     FINISH_DONE
869   };
870   State next_state_;
871   std::function<void(grpc::Status, ByteBuffer*)> callback_;
872   std::function<gpr_timespec()> next_issue_;
873   std::function<std::unique_ptr<grpc::GenericClientAsyncReaderWriter>(
874       grpc::GenericStub*, grpc::ClientContext*, const grpc::string&,
875       CompletionQueue*)>
876       prepare_req_;
877   grpc::Status status_;
878   double start_;
879   std::unique_ptr<grpc::GenericClientAsyncReaderWriter> stream_;
880 
881   // Allow a limit on number of messages in a stream
882   int messages_per_stream_;
883   int messages_issued_;
884 
StartInternal(CompletionQueue * cq,int messages_per_stream)885   void StartInternal(CompletionQueue* cq, int messages_per_stream) {
886     cq_ = cq;
887     const grpc::string kMethodName(
888         "/grpc.testing.BenchmarkService/StreamingCall");
889     messages_per_stream_ = messages_per_stream;
890     messages_issued_ = 0;
891     stream_ = prepare_req_(stub_, &context_, kMethodName, cq);
892     next_state_ = State::STREAM_IDLE;
893     stream_->StartCall(ClientRpcContext::tag(this));
894   }
895 };
896 
GenericStubCreator(const std::shared_ptr<Channel> & ch)897 static std::unique_ptr<grpc::GenericStub> GenericStubCreator(
898     const std::shared_ptr<Channel>& ch) {
899   return std::unique_ptr<grpc::GenericStub>(new grpc::GenericStub(ch));
900 }
901 
902 class GenericAsyncStreamingClient final
903     : public AsyncClient<grpc::GenericStub, ByteBuffer> {
904  public:
GenericAsyncStreamingClient(const ClientConfig & config)905   explicit GenericAsyncStreamingClient(const ClientConfig& config)
906       : AsyncClient<grpc::GenericStub, ByteBuffer>(config, SetupCtx,
907                                                    GenericStubCreator) {
908     StartThreads(num_async_threads_);
909   }
910 
~GenericAsyncStreamingClient()911   ~GenericAsyncStreamingClient() override {}
912 
913  private:
CheckDone(const grpc::Status & s,ByteBuffer * response)914   static void CheckDone(const grpc::Status& s, ByteBuffer* response) {}
PrepareReq(grpc::GenericStub * stub,grpc::ClientContext * ctx,const grpc::string & method_name,CompletionQueue * cq)915   static std::unique_ptr<grpc::GenericClientAsyncReaderWriter> PrepareReq(
916       grpc::GenericStub* stub, grpc::ClientContext* ctx,
917       const grpc::string& method_name, CompletionQueue* cq) {
918     auto stream = stub->PrepareCall(ctx, method_name, cq);
919     return stream;
920   };
SetupCtx(grpc::GenericStub * stub,std::function<gpr_timespec ()> next_issue,const ByteBuffer & req)921   static ClientRpcContext* SetupCtx(grpc::GenericStub* stub,
922                                     std::function<gpr_timespec()> next_issue,
923                                     const ByteBuffer& req) {
924     return new ClientRpcContextGenericStreamingImpl(
925         stub, req, std::move(next_issue),
926         GenericAsyncStreamingClient::PrepareReq,
927         GenericAsyncStreamingClient::CheckDone);
928   }
929 };
930 
CreateAsyncClient(const ClientConfig & config)931 std::unique_ptr<Client> CreateAsyncClient(const ClientConfig& config) {
932   switch (config.rpc_type()) {
933     case UNARY:
934       return std::unique_ptr<Client>(new AsyncUnaryClient(config));
935     case STREAMING:
936       return std::unique_ptr<Client>(new AsyncStreamingPingPongClient(config));
937     case STREAMING_FROM_CLIENT:
938       return std::unique_ptr<Client>(
939           new AsyncStreamingFromClientClient(config));
940     case STREAMING_FROM_SERVER:
941       return std::unique_ptr<Client>(
942           new AsyncStreamingFromServerClient(config));
943     case STREAMING_BOTH_WAYS:
944       // TODO(vjpai): Implement this
945       assert(false);
946       return nullptr;
947     default:
948       assert(false);
949       return nullptr;
950   }
951 }
CreateGenericAsyncStreamingClient(const ClientConfig & args)952 std::unique_ptr<Client> CreateGenericAsyncStreamingClient(
953     const ClientConfig& args) {
954   return std::unique_ptr<Client>(new GenericAsyncStreamingClient(args));
955 }
956 
957 }  // namespace testing
958 }  // namespace grpc
959