1 /*
2  *
3  * Copyright 2015 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
19 #include <grpc/support/port_platform.h>
20 
21 #include "src/core/lib/security/transport/security_handshaker.h"
22 
23 #include <stdbool.h>
24 #include <string.h>
25 #include <limits>
26 
27 #include <grpc/slice_buffer.h>
28 #include <grpc/support/alloc.h>
29 #include <grpc/support/log.h>
30 
31 #include "src/core/lib/channel/channel_args.h"
32 #include "src/core/lib/channel/handshaker.h"
33 #include "src/core/lib/channel/handshaker_registry.h"
34 #include "src/core/lib/gprpp/ref_counted_ptr.h"
35 #include "src/core/lib/security/context/security_context.h"
36 #include "src/core/lib/security/transport/secure_endpoint.h"
37 #include "src/core/lib/security/transport/tsi_error.h"
38 #include "src/core/lib/slice/slice_internal.h"
39 #include "src/core/tsi/transport_security_grpc.h"
40 
41 #define GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE 256
42 
43 namespace grpc_core {
44 
45 namespace {
46 
47 class SecurityHandshaker : public Handshaker {
48  public:
49   SecurityHandshaker(tsi_handshaker* handshaker,
50                      grpc_security_connector* connector,
51                      const grpc_channel_args* args);
52   ~SecurityHandshaker() override;
53   void Shutdown(grpc_error* why) override;
54   void DoHandshake(grpc_tcp_server_acceptor* acceptor,
55                    grpc_closure* on_handshake_done,
56                    HandshakerArgs* args) override;
name() const57   const char* name() const override { return "security"; }
58 
59  private:
60   grpc_error* DoHandshakerNextLocked(const unsigned char* bytes_received,
61                                      size_t bytes_received_size);
62 
63   grpc_error* OnHandshakeNextDoneLocked(
64       tsi_result result, const unsigned char* bytes_to_send,
65       size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result);
66   void HandshakeFailedLocked(grpc_error* error);
67   void CleanupArgsForFailureLocked();
68 
69   static void OnHandshakeDataReceivedFromPeerFn(void* arg, grpc_error* error);
70   static void OnHandshakeDataSentToPeerFn(void* arg, grpc_error* error);
71   static void OnHandshakeDataReceivedFromPeerFnScheduler(void* arg,
72                                                          grpc_error* error);
73   static void OnHandshakeDataSentToPeerFnScheduler(void* arg,
74                                                    grpc_error* error);
75   static void OnHandshakeNextDoneGrpcWrapper(
76       tsi_result result, void* user_data, const unsigned char* bytes_to_send,
77       size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result);
78   static void OnPeerCheckedFn(void* arg, grpc_error* error);
79   void OnPeerCheckedInner(grpc_error* error);
80   size_t MoveReadBufferIntoHandshakeBuffer();
81   grpc_error* CheckPeerLocked();
82 
83   // State set at creation time.
84   tsi_handshaker* handshaker_;
85   RefCountedPtr<grpc_security_connector> connector_;
86 
87   gpr_mu mu_;
88 
89   bool is_shutdown_ = false;
90   // Endpoint and read buffer to destroy after a shutdown.
91   grpc_endpoint* endpoint_to_destroy_ = nullptr;
92   grpc_slice_buffer* read_buffer_to_destroy_ = nullptr;
93 
94   // State saved while performing the handshake.
95   HandshakerArgs* args_ = nullptr;
96   grpc_closure* on_handshake_done_ = nullptr;
97 
98   size_t handshake_buffer_size_;
99   unsigned char* handshake_buffer_;
100   grpc_slice_buffer outgoing_;
101   grpc_closure on_handshake_data_sent_to_peer_;
102   grpc_closure on_handshake_data_received_from_peer_;
103   grpc_closure on_peer_checked_;
104   RefCountedPtr<grpc_auth_context> auth_context_;
105   tsi_handshaker_result* handshaker_result_ = nullptr;
106   size_t max_frame_size_ = 0;
107 };
108 
SecurityHandshaker(tsi_handshaker * handshaker,grpc_security_connector * connector,const grpc_channel_args * args)109 SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker,
110                                        grpc_security_connector* connector,
111                                        const grpc_channel_args* args)
112     : handshaker_(handshaker),
113       connector_(connector->Ref(DEBUG_LOCATION, "handshake")),
114       handshake_buffer_size_(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE),
115       handshake_buffer_(
116           static_cast<uint8_t*>(gpr_malloc(handshake_buffer_size_))) {
117   const grpc_arg* arg =
118       grpc_channel_args_find(args, GRPC_ARG_TSI_MAX_FRAME_SIZE);
119   if (arg != nullptr && arg->type == GRPC_ARG_INTEGER) {
120     max_frame_size_ = grpc_channel_arg_get_integer(
121         arg, {0, 0, std::numeric_limits<int>::max()});
122   }
123   gpr_mu_init(&mu_);
124   grpc_slice_buffer_init(&outgoing_);
125   GRPC_CLOSURE_INIT(&on_peer_checked_, &SecurityHandshaker::OnPeerCheckedFn,
126                     this, grpc_schedule_on_exec_ctx);
127 }
128 
~SecurityHandshaker()129 SecurityHandshaker::~SecurityHandshaker() {
130   gpr_mu_destroy(&mu_);
131   tsi_handshaker_destroy(handshaker_);
132   tsi_handshaker_result_destroy(handshaker_result_);
133   if (endpoint_to_destroy_ != nullptr) {
134     grpc_endpoint_destroy(endpoint_to_destroy_);
135   }
136   if (read_buffer_to_destroy_ != nullptr) {
137     grpc_slice_buffer_destroy_internal(read_buffer_to_destroy_);
138     gpr_free(read_buffer_to_destroy_);
139   }
140   gpr_free(handshake_buffer_);
141   grpc_slice_buffer_destroy_internal(&outgoing_);
142   auth_context_.reset(DEBUG_LOCATION, "handshake");
143   connector_.reset(DEBUG_LOCATION, "handshake");
144 }
145 
MoveReadBufferIntoHandshakeBuffer()146 size_t SecurityHandshaker::MoveReadBufferIntoHandshakeBuffer() {
147   size_t bytes_in_read_buffer = args_->read_buffer->length;
148   if (handshake_buffer_size_ < bytes_in_read_buffer) {
149     handshake_buffer_ = static_cast<uint8_t*>(
150         gpr_realloc(handshake_buffer_, bytes_in_read_buffer));
151     handshake_buffer_size_ = bytes_in_read_buffer;
152   }
153   size_t offset = 0;
154   while (args_->read_buffer->count > 0) {
155     grpc_slice* next_slice = grpc_slice_buffer_peek_first(args_->read_buffer);
156     memcpy(handshake_buffer_ + offset, GRPC_SLICE_START_PTR(*next_slice),
157            GRPC_SLICE_LENGTH(*next_slice));
158     offset += GRPC_SLICE_LENGTH(*next_slice);
159     grpc_slice_buffer_remove_first(args_->read_buffer);
160   }
161   return bytes_in_read_buffer;
162 }
163 
164 // Set args_ fields to NULL, saving the endpoint and read buffer for
165 // later destruction.
CleanupArgsForFailureLocked()166 void SecurityHandshaker::CleanupArgsForFailureLocked() {
167   endpoint_to_destroy_ = args_->endpoint;
168   args_->endpoint = nullptr;
169   read_buffer_to_destroy_ = args_->read_buffer;
170   args_->read_buffer = nullptr;
171   grpc_channel_args_destroy(args_->args);
172   args_->args = nullptr;
173 }
174 
175 // If the handshake failed or we're shutting down, clean up and invoke the
176 // callback with the error.
HandshakeFailedLocked(grpc_error * error)177 void SecurityHandshaker::HandshakeFailedLocked(grpc_error* error) {
178   if (error == GRPC_ERROR_NONE) {
179     // If we were shut down after the handshake succeeded but before an
180     // endpoint callback was invoked, we need to generate our own error.
181     error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshaker shutdown");
182   }
183   const char* msg = grpc_error_string(error);
184   gpr_log(GPR_DEBUG, "Security handshake failed: %s", msg);
185 
186   if (!is_shutdown_) {
187     tsi_handshaker_shutdown(handshaker_);
188     // TODO(ctiller): It is currently necessary to shutdown endpoints
189     // before destroying them, even if we know that there are no
190     // pending read/write callbacks.  This should be fixed, at which
191     // point this can be removed.
192     grpc_endpoint_shutdown(args_->endpoint, GRPC_ERROR_REF(error));
193     // Not shutting down, so the write failed.  Clean up before
194     // invoking the callback.
195     CleanupArgsForFailureLocked();
196     // Set shutdown to true so that subsequent calls to
197     // security_handshaker_shutdown() do nothing.
198     is_shutdown_ = true;
199   }
200   // Invoke callback.
201   ExecCtx::Run(DEBUG_LOCATION, on_handshake_done_, error);
202 }
203 
OnPeerCheckedInner(grpc_error * error)204 void SecurityHandshaker::OnPeerCheckedInner(grpc_error* error) {
205   MutexLock lock(&mu_);
206   if (error != GRPC_ERROR_NONE || is_shutdown_) {
207     HandshakeFailedLocked(error);
208     return;
209   }
210   // Create zero-copy frame protector, if implemented.
211   tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr;
212   tsi_result result = tsi_handshaker_result_create_zero_copy_grpc_protector(
213       handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_,
214       &zero_copy_protector);
215   if (result != TSI_OK && result != TSI_UNIMPLEMENTED) {
216     error = grpc_set_tsi_error_result(
217         GRPC_ERROR_CREATE_FROM_STATIC_STRING(
218             "Zero-copy frame protector creation failed"),
219         result);
220     HandshakeFailedLocked(error);
221     return;
222   }
223   // Create frame protector if zero-copy frame protector is NULL.
224   tsi_frame_protector* protector = nullptr;
225   if (zero_copy_protector == nullptr) {
226     result = tsi_handshaker_result_create_frame_protector(
227         handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_,
228         &protector);
229     if (result != TSI_OK) {
230       error = grpc_set_tsi_error_result(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
231                                             "Frame protector creation failed"),
232                                         result);
233       HandshakeFailedLocked(error);
234       return;
235     }
236   }
237   // Get unused bytes.
238   const unsigned char* unused_bytes = nullptr;
239   size_t unused_bytes_size = 0;
240   result = tsi_handshaker_result_get_unused_bytes(
241       handshaker_result_, &unused_bytes, &unused_bytes_size);
242   // Create secure endpoint.
243   if (unused_bytes_size > 0) {
244     grpc_slice slice = grpc_slice_from_copied_buffer(
245         reinterpret_cast<const char*>(unused_bytes), unused_bytes_size);
246     args_->endpoint = grpc_secure_endpoint_create(
247         protector, zero_copy_protector, args_->endpoint, &slice, 1);
248     grpc_slice_unref_internal(slice);
249   } else {
250     args_->endpoint = grpc_secure_endpoint_create(
251         protector, zero_copy_protector, args_->endpoint, nullptr, 0);
252   }
253   tsi_handshaker_result_destroy(handshaker_result_);
254   handshaker_result_ = nullptr;
255   // Add auth context to channel args.
256   grpc_arg auth_context_arg = grpc_auth_context_to_arg(auth_context_.get());
257   grpc_channel_args* tmp_args = args_->args;
258   args_->args = grpc_channel_args_copy_and_add(tmp_args, &auth_context_arg, 1);
259   grpc_channel_args_destroy(tmp_args);
260   // Invoke callback.
261   ExecCtx::Run(DEBUG_LOCATION, on_handshake_done_, GRPC_ERROR_NONE);
262   // Set shutdown to true so that subsequent calls to
263   // security_handshaker_shutdown() do nothing.
264   is_shutdown_ = true;
265 }
266 
OnPeerCheckedFn(void * arg,grpc_error * error)267 void SecurityHandshaker::OnPeerCheckedFn(void* arg, grpc_error* error) {
268   RefCountedPtr<SecurityHandshaker>(static_cast<SecurityHandshaker*>(arg))
269       ->OnPeerCheckedInner(GRPC_ERROR_REF(error));
270 }
271 
CheckPeerLocked()272 grpc_error* SecurityHandshaker::CheckPeerLocked() {
273   tsi_peer peer;
274   tsi_result result =
275       tsi_handshaker_result_extract_peer(handshaker_result_, &peer);
276   if (result != TSI_OK) {
277     return grpc_set_tsi_error_result(
278         GRPC_ERROR_CREATE_FROM_STATIC_STRING("Peer extraction failed"), result);
279   }
280   connector_->check_peer(peer, args_->endpoint, &auth_context_,
281                          &on_peer_checked_);
282   return GRPC_ERROR_NONE;
283 }
284 
OnHandshakeNextDoneLocked(tsi_result result,const unsigned char * bytes_to_send,size_t bytes_to_send_size,tsi_handshaker_result * handshaker_result)285 grpc_error* SecurityHandshaker::OnHandshakeNextDoneLocked(
286     tsi_result result, const unsigned char* bytes_to_send,
287     size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) {
288   grpc_error* error = GRPC_ERROR_NONE;
289   // Handshaker was shutdown.
290   if (is_shutdown_) {
291     return GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshaker shutdown");
292   }
293   // Read more if we need to.
294   if (result == TSI_INCOMPLETE_DATA) {
295     GPR_ASSERT(bytes_to_send_size == 0);
296     grpc_endpoint_read(
297         args_->endpoint, args_->read_buffer,
298         GRPC_CLOSURE_INIT(
299             &on_handshake_data_received_from_peer_,
300             &SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler,
301             this, grpc_schedule_on_exec_ctx),
302         /*urgent=*/true);
303     return error;
304   }
305   if (result != TSI_OK) {
306     return grpc_set_tsi_error_result(
307         GRPC_ERROR_CREATE_FROM_STATIC_STRING("Handshake failed"), result);
308   }
309   // Update handshaker result.
310   if (handshaker_result != nullptr) {
311     GPR_ASSERT(handshaker_result_ == nullptr);
312     handshaker_result_ = handshaker_result;
313   }
314   if (bytes_to_send_size > 0) {
315     // Send data to peer, if needed.
316     grpc_slice to_send = grpc_slice_from_copied_buffer(
317         reinterpret_cast<const char*>(bytes_to_send), bytes_to_send_size);
318     grpc_slice_buffer_reset_and_unref_internal(&outgoing_);
319     grpc_slice_buffer_add(&outgoing_, to_send);
320     grpc_endpoint_write(
321         args_->endpoint, &outgoing_,
322         GRPC_CLOSURE_INIT(
323             &on_handshake_data_sent_to_peer_,
324             &SecurityHandshaker::OnHandshakeDataSentToPeerFnScheduler, this,
325             grpc_schedule_on_exec_ctx),
326         nullptr);
327   } else if (handshaker_result == nullptr) {
328     // There is nothing to send, but need to read from peer.
329     grpc_endpoint_read(
330         args_->endpoint, args_->read_buffer,
331         GRPC_CLOSURE_INIT(
332             &on_handshake_data_received_from_peer_,
333             &SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler,
334             this, grpc_schedule_on_exec_ctx),
335         /*urgent=*/true);
336   } else {
337     // Handshake has finished, check peer and so on.
338     error = CheckPeerLocked();
339   }
340   return error;
341 }
342 
OnHandshakeNextDoneGrpcWrapper(tsi_result result,void * user_data,const unsigned char * bytes_to_send,size_t bytes_to_send_size,tsi_handshaker_result * handshaker_result)343 void SecurityHandshaker::OnHandshakeNextDoneGrpcWrapper(
344     tsi_result result, void* user_data, const unsigned char* bytes_to_send,
345     size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) {
346   RefCountedPtr<SecurityHandshaker> h(
347       static_cast<SecurityHandshaker*>(user_data));
348   MutexLock lock(&h->mu_);
349   grpc_error* error = h->OnHandshakeNextDoneLocked(
350       result, bytes_to_send, bytes_to_send_size, handshaker_result);
351   if (error != GRPC_ERROR_NONE) {
352     h->HandshakeFailedLocked(error);
353   } else {
354     h.release();  // Avoid unref
355   }
356 }
357 
DoHandshakerNextLocked(const unsigned char * bytes_received,size_t bytes_received_size)358 grpc_error* SecurityHandshaker::DoHandshakerNextLocked(
359     const unsigned char* bytes_received, size_t bytes_received_size) {
360   // Invoke TSI handshaker.
361   const unsigned char* bytes_to_send = nullptr;
362   size_t bytes_to_send_size = 0;
363   tsi_handshaker_result* hs_result = nullptr;
364   tsi_result result = tsi_handshaker_next(
365       handshaker_, bytes_received, bytes_received_size, &bytes_to_send,
366       &bytes_to_send_size, &hs_result, &OnHandshakeNextDoneGrpcWrapper, this);
367   if (result == TSI_ASYNC) {
368     // Handshaker operating asynchronously. Nothing else to do here;
369     // callback will be invoked in a TSI thread.
370     return GRPC_ERROR_NONE;
371   }
372   // Handshaker returned synchronously. Invoke callback directly in
373   // this thread with our existing exec_ctx.
374   return OnHandshakeNextDoneLocked(result, bytes_to_send, bytes_to_send_size,
375                                    hs_result);
376 }
377 
378 // This callback might be run inline while we are still holding on to the mutex,
379 // so schedule OnHandshakeDataReceivedFromPeerFn on ExecCtx to avoid a deadlock.
OnHandshakeDataReceivedFromPeerFnScheduler(void * arg,grpc_error * error)380 void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler(
381     void* arg, grpc_error* error) {
382   SecurityHandshaker* h = static_cast<SecurityHandshaker*>(arg);
383   grpc_core::ExecCtx::Run(
384       DEBUG_LOCATION,
385       GRPC_CLOSURE_INIT(&h->on_handshake_data_received_from_peer_,
386                         &SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn,
387                         h, grpc_schedule_on_exec_ctx),
388       GRPC_ERROR_REF(error));
389 }
390 
OnHandshakeDataReceivedFromPeerFn(void * arg,grpc_error * error)391 void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn(void* arg,
392                                                            grpc_error* error) {
393   RefCountedPtr<SecurityHandshaker> h(static_cast<SecurityHandshaker*>(arg));
394   MutexLock lock(&h->mu_);
395   if (error != GRPC_ERROR_NONE || h->is_shutdown_) {
396     h->HandshakeFailedLocked(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
397         "Handshake read failed", &error, 1));
398     return;
399   }
400   // Copy all slices received.
401   size_t bytes_received_size = h->MoveReadBufferIntoHandshakeBuffer();
402   // Call TSI handshaker.
403   error = h->DoHandshakerNextLocked(h->handshake_buffer_, bytes_received_size);
404 
405   if (error != GRPC_ERROR_NONE) {
406     h->HandshakeFailedLocked(error);
407   } else {
408     h.release();  // Avoid unref
409   }
410 }
411 
412 // This callback might be run inline while we are still holding on to the mutex,
413 // so schedule OnHandshakeDataSentToPeerFn on ExecCtx to avoid a deadlock.
OnHandshakeDataSentToPeerFnScheduler(void * arg,grpc_error * error)414 void SecurityHandshaker::OnHandshakeDataSentToPeerFnScheduler(
415     void* arg, grpc_error* error) {
416   SecurityHandshaker* h = static_cast<SecurityHandshaker*>(arg);
417   grpc_core::ExecCtx::Run(
418       DEBUG_LOCATION,
419       GRPC_CLOSURE_INIT(&h->on_handshake_data_sent_to_peer_,
420                         &SecurityHandshaker::OnHandshakeDataSentToPeerFn, h,
421                         grpc_schedule_on_exec_ctx),
422       GRPC_ERROR_REF(error));
423 }
424 
OnHandshakeDataSentToPeerFn(void * arg,grpc_error * error)425 void SecurityHandshaker::OnHandshakeDataSentToPeerFn(void* arg,
426                                                      grpc_error* error) {
427   RefCountedPtr<SecurityHandshaker> h(static_cast<SecurityHandshaker*>(arg));
428   MutexLock lock(&h->mu_);
429   if (error != GRPC_ERROR_NONE || h->is_shutdown_) {
430     h->HandshakeFailedLocked(GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
431         "Handshake write failed", &error, 1));
432     return;
433   }
434   // We may be done.
435   if (h->handshaker_result_ == nullptr) {
436     grpc_endpoint_read(
437         h->args_->endpoint, h->args_->read_buffer,
438         GRPC_CLOSURE_INIT(
439             &h->on_handshake_data_received_from_peer_,
440             &SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler,
441             h.get(), grpc_schedule_on_exec_ctx),
442         /*urgent=*/true);
443   } else {
444     error = h->CheckPeerLocked();
445     if (error != GRPC_ERROR_NONE) {
446       h->HandshakeFailedLocked(error);
447       return;
448     }
449   }
450   h.release();  // Avoid unref
451 }
452 
453 //
454 // public handshaker API
455 //
456 
Shutdown(grpc_error * why)457 void SecurityHandshaker::Shutdown(grpc_error* why) {
458   MutexLock lock(&mu_);
459   if (!is_shutdown_) {
460     is_shutdown_ = true;
461     tsi_handshaker_shutdown(handshaker_);
462     grpc_endpoint_shutdown(args_->endpoint, GRPC_ERROR_REF(why));
463     CleanupArgsForFailureLocked();
464   }
465   GRPC_ERROR_UNREF(why);
466 }
467 
DoHandshake(grpc_tcp_server_acceptor *,grpc_closure * on_handshake_done,HandshakerArgs * args)468 void SecurityHandshaker::DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/,
469                                      grpc_closure* on_handshake_done,
470                                      HandshakerArgs* args) {
471   auto ref = Ref();
472   MutexLock lock(&mu_);
473   args_ = args;
474   on_handshake_done_ = on_handshake_done;
475   size_t bytes_received_size = MoveReadBufferIntoHandshakeBuffer();
476   grpc_error* error =
477       DoHandshakerNextLocked(handshake_buffer_, bytes_received_size);
478   if (error != GRPC_ERROR_NONE) {
479     HandshakeFailedLocked(error);
480   } else {
481     ref.release();  // Avoid unref
482   }
483 }
484 
485 //
486 // FailHandshaker
487 //
488 
489 class FailHandshaker : public Handshaker {
490  public:
name() const491   const char* name() const override { return "security_fail"; }
Shutdown(grpc_error * why)492   void Shutdown(grpc_error* why) override { GRPC_ERROR_UNREF(why); }
DoHandshake(grpc_tcp_server_acceptor *,grpc_closure * on_handshake_done,HandshakerArgs *)493   void DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/,
494                    grpc_closure* on_handshake_done,
495                    HandshakerArgs* /*args*/) override {
496     ExecCtx::Run(DEBUG_LOCATION, on_handshake_done,
497                  GRPC_ERROR_CREATE_FROM_STATIC_STRING(
498                      "Failed to create security handshaker"));
499   }
500 
501  private:
502   ~FailHandshaker() override = default;
503 };
504 
505 //
506 // handshaker factories
507 //
508 
509 class ClientSecurityHandshakerFactory : public HandshakerFactory {
510  public:
AddHandshakers(const grpc_channel_args * args,grpc_pollset_set * interested_parties,HandshakeManager * handshake_mgr)511   void AddHandshakers(const grpc_channel_args* args,
512                       grpc_pollset_set* interested_parties,
513                       HandshakeManager* handshake_mgr) override {
514     auto* security_connector =
515         reinterpret_cast<grpc_channel_security_connector*>(
516             grpc_security_connector_find_in_args(args));
517     if (security_connector) {
518       security_connector->add_handshakers(args, interested_parties,
519                                           handshake_mgr);
520     }
521   }
522   ~ClientSecurityHandshakerFactory() override = default;
523 };
524 
525 class ServerSecurityHandshakerFactory : public HandshakerFactory {
526  public:
AddHandshakers(const grpc_channel_args * args,grpc_pollset_set * interested_parties,HandshakeManager * handshake_mgr)527   void AddHandshakers(const grpc_channel_args* args,
528                       grpc_pollset_set* interested_parties,
529                       HandshakeManager* handshake_mgr) override {
530     auto* security_connector =
531         reinterpret_cast<grpc_server_security_connector*>(
532             grpc_security_connector_find_in_args(args));
533     if (security_connector) {
534       security_connector->add_handshakers(args, interested_parties,
535                                           handshake_mgr);
536     }
537   }
538   ~ServerSecurityHandshakerFactory() override = default;
539 };
540 
541 }  // namespace
542 
543 //
544 // exported functions
545 //
546 
SecurityHandshakerCreate(tsi_handshaker * handshaker,grpc_security_connector * connector,const grpc_channel_args * args)547 RefCountedPtr<Handshaker> SecurityHandshakerCreate(
548     tsi_handshaker* handshaker, grpc_security_connector* connector,
549     const grpc_channel_args* args) {
550   // If no TSI handshaker was created, return a handshaker that always fails.
551   // Otherwise, return a real security handshaker.
552   if (handshaker == nullptr) {
553     return MakeRefCounted<FailHandshaker>();
554   } else {
555     return MakeRefCounted<SecurityHandshaker>(handshaker, connector, args);
556   }
557 }
558 
SecurityRegisterHandshakerFactories()559 void SecurityRegisterHandshakerFactories() {
560   HandshakerRegistry::RegisterHandshakerFactory(
561       false /* at_start */, HANDSHAKER_CLIENT,
562       absl::make_unique<ClientSecurityHandshakerFactory>());
563   HandshakerRegistry::RegisterHandshakerFactory(
564       false /* at_start */, HANDSHAKER_SERVER,
565       absl::make_unique<ServerSecurityHandshakerFactory>());
566 }
567 
568 }  // namespace grpc_core
569 
grpc_security_handshaker_create(tsi_handshaker * handshaker,grpc_security_connector * connector,const grpc_channel_args * args)570 grpc_handshaker* grpc_security_handshaker_create(
571     tsi_handshaker* handshaker, grpc_security_connector* connector,
572     const grpc_channel_args* args) {
573   return SecurityHandshakerCreate(handshaker, connector, args).release();
574 }
575