1 /*
2  *
3  * Copyright 2015 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
/* With the addition of a libuv endpoint, sockaddr.h now includes uv.h when
   using that endpoint. Because of various transitive includes in uv.h,
   including windows.h on Windows, uv.h must be included before other system
   headers. Therefore, sockaddr.h must always be included before any other
   header except grpc/support/port_platform.h, which always comes first. */
23 #include <grpc/support/port_platform.h>
24 
25 #include "src/core/lib/iomgr/sockaddr.h"
26 
27 #include <grpc/slice.h>
28 #include <grpc/slice_buffer.h>
29 #include <grpc/support/alloc.h>
30 #include <grpc/support/log.h>
31 #include <grpc/support/sync.h>
32 #include "src/core/lib/debug/trace.h"
33 #include "src/core/lib/gpr/string.h"
34 #include "src/core/lib/profiling/timers.h"
35 #include "src/core/lib/security/transport/secure_endpoint.h"
36 #include "src/core/lib/security/transport/tsi_error.h"
37 #include "src/core/lib/slice/slice_internal.h"
38 #include "src/core/lib/slice/slice_string_helpers.h"
39 #include "src/core/tsi/transport_security_grpc.h"
40 
/* Size in bytes of each read/write staging slice. A fresh slice of this size
   is allocated every time the current one fills up. */
#define STAGING_BUFFER_SIZE 8192

/* State for an endpoint that encrypts writes and decrypts reads on a wrapped
   endpoint using a TSI protector. */
typedef struct {
  /* Must remain the first member: the vtable functions convert the public
     grpc_endpoint* back into a secure_endpoint* with reinterpret_cast, and
     grpc_secure_endpoint_create hands out &ep->base. */
  grpc_endpoint base;
  /* Underlying transport carrying the encrypted bytes. Owned; destroyed in
     destroy(). */
  grpc_endpoint* wrapped_ep;
  /* Frame protector, used only when zero_copy_protector is null. Owned. */
  struct tsi_frame_protector* protector;
  /* Zero-copy protector; takes precedence over `protector` when non-null.
     Owned. */
  struct tsi_zero_copy_grpc_protector* zero_copy_protector;
  /* Serializes calls into `protector` (protect, protect_flush and unprotect)
     made from the write path and the read path. */
  gpr_mu protector_mu;
  /* saved upper level callbacks. read_cb is set by endpoint_read and
     scheduled by call_read_cb. */
  grpc_closure* read_cb;
  /* NOTE(review): not referenced anywhere in this file. */
  grpc_closure* write_cb;
  /* Completion closure for reads on wrapped_ep; runs on_read(). */
  grpc_closure on_read;
  /* Caller-supplied destination for decrypted bytes. Set in endpoint_read
     and reset to nullptr in call_read_cb. */
  grpc_slice_buffer* read_buffer;
  /* Encrypted bytes received from wrapped_ep, awaiting decryption. */
  grpc_slice_buffer source_buffer;
  /* saved handshaker leftover data to unprotect. Consumed by the first
     endpoint_read before anything is read from wrapped_ep. */
  grpc_slice_buffer leftover_bytes;
  /* Staging slices: decrypted bytes accumulate in read_staging_buffer and
     encrypted bytes in write_staging_buffer before being handed off. */
  grpc_slice read_staging_buffer;

  grpc_slice write_staging_buffer;
  /* Encrypted bytes passed to wrapped_ep's write. */
  grpc_slice_buffer output_buffer;

  /* Reference count. It starts at 1 for the owner (released by
     endpoint_destroy), and each pending read holds one more reference. */
  gpr_refcount ref;
} secure_endpoint;
65 
/* Tracer "secure_endpoint" (constructed with `false`, presumably off by
   default). When enabled it logs ref/unref transitions in debug builds and
   hex/ASCII dumps of every plaintext slice read and written. Beware: those
   dumps put decrypted application data into the logs. */
grpc_core::TraceFlag grpc_trace_secure_endpoint(false, "secure_endpoint");
67 
destroy(secure_endpoint * secure_ep)68 static void destroy(secure_endpoint* secure_ep) {
69   secure_endpoint* ep = secure_ep;
70   grpc_endpoint_destroy(ep->wrapped_ep);
71   tsi_frame_protector_destroy(ep->protector);
72   tsi_zero_copy_grpc_protector_destroy(ep->zero_copy_protector);
73   grpc_slice_buffer_destroy_internal(&ep->leftover_bytes);
74   grpc_slice_unref_internal(ep->read_staging_buffer);
75   grpc_slice_unref_internal(ep->write_staging_buffer);
76   grpc_slice_buffer_destroy_internal(&ep->output_buffer);
77   grpc_slice_buffer_destroy_internal(&ep->source_buffer);
78   gpr_mu_destroy(&ep->protector_mu);
79   gpr_free(ep);
80 }
81 
82 #ifndef NDEBUG
83 #define SECURE_ENDPOINT_UNREF(ep, reason) \
84   secure_endpoint_unref((ep), (reason), __FILE__, __LINE__)
85 #define SECURE_ENDPOINT_REF(ep, reason) \
86   secure_endpoint_ref((ep), (reason), __FILE__, __LINE__)
secure_endpoint_unref(secure_endpoint * ep,const char * reason,const char * file,int line)87 static void secure_endpoint_unref(secure_endpoint* ep, const char* reason,
88                                   const char* file, int line) {
89   if (grpc_trace_secure_endpoint.enabled()) {
90     gpr_atm val = gpr_atm_no_barrier_load(&ep->ref.count);
91     gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG,
92             "SECENDP unref %p : %s %" PRIdPTR " -> %" PRIdPTR, ep, reason, val,
93             val - 1);
94   }
95   if (gpr_unref(&ep->ref)) {
96     destroy(ep);
97   }
98 }
99 
secure_endpoint_ref(secure_endpoint * ep,const char * reason,const char * file,int line)100 static void secure_endpoint_ref(secure_endpoint* ep, const char* reason,
101                                 const char* file, int line) {
102   if (grpc_trace_secure_endpoint.enabled()) {
103     gpr_atm val = gpr_atm_no_barrier_load(&ep->ref.count);
104     gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG,
105             "SECENDP   ref %p : %s %" PRIdPTR " -> %" PRIdPTR, ep, reason, val,
106             val + 1);
107   }
108   gpr_ref(&ep->ref);
109 }
110 #else
111 #define SECURE_ENDPOINT_UNREF(ep, reason) secure_endpoint_unref((ep))
112 #define SECURE_ENDPOINT_REF(ep, reason) secure_endpoint_ref((ep))
secure_endpoint_unref(secure_endpoint * ep)113 static void secure_endpoint_unref(secure_endpoint* ep) {
114   if (gpr_unref(&ep->ref)) {
115     destroy(ep);
116   }
117 }
118 
secure_endpoint_ref(secure_endpoint * ep)119 static void secure_endpoint_ref(secure_endpoint* ep) { gpr_ref(&ep->ref); }
120 #endif
121 
flush_read_staging_buffer(secure_endpoint * ep,uint8_t ** cur,uint8_t ** end)122 static void flush_read_staging_buffer(secure_endpoint* ep, uint8_t** cur,
123                                       uint8_t** end) {
124   grpc_slice_buffer_add(ep->read_buffer, ep->read_staging_buffer);
125   ep->read_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE);
126   *cur = GRPC_SLICE_START_PTR(ep->read_staging_buffer);
127   *end = GRPC_SLICE_END_PTR(ep->read_staging_buffer);
128 }
129 
call_read_cb(secure_endpoint * ep,grpc_error * error)130 static void call_read_cb(secure_endpoint* ep, grpc_error* error) {
131   if (grpc_trace_secure_endpoint.enabled()) {
132     size_t i;
133     for (i = 0; i < ep->read_buffer->count; i++) {
134       char* data = grpc_dump_slice(ep->read_buffer->slices[i],
135                                    GPR_DUMP_HEX | GPR_DUMP_ASCII);
136       gpr_log(GPR_INFO, "READ %p: %s", ep, data);
137       gpr_free(data);
138     }
139   }
140   ep->read_buffer = nullptr;
141   GRPC_CLOSURE_SCHED(ep->read_cb, error);
142   SECURE_ENDPOINT_UNREF(ep, "read");
143 }
144 
on_read(void * user_data,grpc_error * error)145 static void on_read(void* user_data, grpc_error* error) {
146   unsigned i;
147   uint8_t keep_looping = 0;
148   tsi_result result = TSI_OK;
149   secure_endpoint* ep = static_cast<secure_endpoint*>(user_data);
150   uint8_t* cur = GRPC_SLICE_START_PTR(ep->read_staging_buffer);
151   uint8_t* end = GRPC_SLICE_END_PTR(ep->read_staging_buffer);
152 
153   if (error != GRPC_ERROR_NONE) {
154     grpc_slice_buffer_reset_and_unref_internal(ep->read_buffer);
155     call_read_cb(ep, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
156                          "Secure read failed", &error, 1));
157     return;
158   }
159 
160   if (ep->zero_copy_protector != nullptr) {
161     // Use zero-copy grpc protector to unprotect.
162     result = tsi_zero_copy_grpc_protector_unprotect(
163         ep->zero_copy_protector, &ep->source_buffer, ep->read_buffer);
164   } else {
165     // Use frame protector to unprotect.
166     /* TODO(yangg) check error, maybe bail out early */
167     for (i = 0; i < ep->source_buffer.count; i++) {
168       grpc_slice encrypted = ep->source_buffer.slices[i];
169       uint8_t* message_bytes = GRPC_SLICE_START_PTR(encrypted);
170       size_t message_size = GRPC_SLICE_LENGTH(encrypted);
171 
172       while (message_size > 0 || keep_looping) {
173         size_t unprotected_buffer_size_written = static_cast<size_t>(end - cur);
174         size_t processed_message_size = message_size;
175         gpr_mu_lock(&ep->protector_mu);
176         result = tsi_frame_protector_unprotect(
177             ep->protector, message_bytes, &processed_message_size, cur,
178             &unprotected_buffer_size_written);
179         gpr_mu_unlock(&ep->protector_mu);
180         if (result != TSI_OK) {
181           gpr_log(GPR_ERROR, "Decryption error: %s",
182                   tsi_result_to_string(result));
183           break;
184         }
185         message_bytes += processed_message_size;
186         message_size -= processed_message_size;
187         cur += unprotected_buffer_size_written;
188 
189         if (cur == end) {
190           flush_read_staging_buffer(ep, &cur, &end);
191           /* Force to enter the loop again to extract buffered bytes in
192              protector. The bytes could be buffered because of running out of
193              staging_buffer. If this happens at the end of all slices, doing
194              another unprotect avoids leaving data in the protector. */
195           keep_looping = 1;
196         } else if (unprotected_buffer_size_written > 0) {
197           keep_looping = 1;
198         } else {
199           keep_looping = 0;
200         }
201       }
202       if (result != TSI_OK) break;
203     }
204 
205     if (cur != GRPC_SLICE_START_PTR(ep->read_staging_buffer)) {
206       grpc_slice_buffer_add(
207           ep->read_buffer,
208           grpc_slice_split_head(
209               &ep->read_staging_buffer,
210               static_cast<size_t>(
211                   cur - GRPC_SLICE_START_PTR(ep->read_staging_buffer))));
212     }
213   }
214 
215   /* TODO(yangg) experiment with moving this block after read_cb to see if it
216      helps latency */
217   grpc_slice_buffer_reset_and_unref_internal(&ep->source_buffer);
218 
219   if (result != TSI_OK) {
220     grpc_slice_buffer_reset_and_unref_internal(ep->read_buffer);
221     call_read_cb(
222         ep, grpc_set_tsi_error_result(
223                 GRPC_ERROR_CREATE_FROM_STATIC_STRING("Unwrap failed"), result));
224     return;
225   }
226 
227   call_read_cb(ep, GRPC_ERROR_NONE);
228 }
229 
endpoint_read(grpc_endpoint * secure_ep,grpc_slice_buffer * slices,grpc_closure * cb)230 static void endpoint_read(grpc_endpoint* secure_ep, grpc_slice_buffer* slices,
231                           grpc_closure* cb) {
232   secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
233   ep->read_cb = cb;
234   ep->read_buffer = slices;
235   grpc_slice_buffer_reset_and_unref_internal(ep->read_buffer);
236 
237   SECURE_ENDPOINT_REF(ep, "read");
238   if (ep->leftover_bytes.count) {
239     grpc_slice_buffer_swap(&ep->leftover_bytes, &ep->source_buffer);
240     GPR_ASSERT(ep->leftover_bytes.count == 0);
241     on_read(ep, GRPC_ERROR_NONE);
242     return;
243   }
244 
245   grpc_endpoint_read(ep->wrapped_ep, &ep->source_buffer, &ep->on_read);
246 }
247 
flush_write_staging_buffer(secure_endpoint * ep,uint8_t ** cur,uint8_t ** end)248 static void flush_write_staging_buffer(secure_endpoint* ep, uint8_t** cur,
249                                        uint8_t** end) {
250   grpc_slice_buffer_add(&ep->output_buffer, ep->write_staging_buffer);
251   ep->write_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE);
252   *cur = GRPC_SLICE_START_PTR(ep->write_staging_buffer);
253   *end = GRPC_SLICE_END_PTR(ep->write_staging_buffer);
254 }
255 
endpoint_write(grpc_endpoint * secure_ep,grpc_slice_buffer * slices,grpc_closure * cb,void * arg)256 static void endpoint_write(grpc_endpoint* secure_ep, grpc_slice_buffer* slices,
257                            grpc_closure* cb, void* arg) {
258   GPR_TIMER_SCOPE("secure_endpoint.endpoint_write", 0);
259 
260   unsigned i;
261   tsi_result result = TSI_OK;
262   secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
263   uint8_t* cur = GRPC_SLICE_START_PTR(ep->write_staging_buffer);
264   uint8_t* end = GRPC_SLICE_END_PTR(ep->write_staging_buffer);
265 
266   grpc_slice_buffer_reset_and_unref_internal(&ep->output_buffer);
267 
268   if (grpc_trace_secure_endpoint.enabled()) {
269     for (i = 0; i < slices->count; i++) {
270       char* data =
271           grpc_dump_slice(slices->slices[i], GPR_DUMP_HEX | GPR_DUMP_ASCII);
272       gpr_log(GPR_INFO, "WRITE %p: %s", ep, data);
273       gpr_free(data);
274     }
275   }
276 
277   if (ep->zero_copy_protector != nullptr) {
278     // Use zero-copy grpc protector to protect.
279     result = tsi_zero_copy_grpc_protector_protect(ep->zero_copy_protector,
280                                                   slices, &ep->output_buffer);
281   } else {
282     // Use frame protector to protect.
283     for (i = 0; i < slices->count; i++) {
284       grpc_slice plain = slices->slices[i];
285       uint8_t* message_bytes = GRPC_SLICE_START_PTR(plain);
286       size_t message_size = GRPC_SLICE_LENGTH(plain);
287       while (message_size > 0) {
288         size_t protected_buffer_size_to_send = static_cast<size_t>(end - cur);
289         size_t processed_message_size = message_size;
290         gpr_mu_lock(&ep->protector_mu);
291         result = tsi_frame_protector_protect(ep->protector, message_bytes,
292                                              &processed_message_size, cur,
293                                              &protected_buffer_size_to_send);
294         gpr_mu_unlock(&ep->protector_mu);
295         if (result != TSI_OK) {
296           gpr_log(GPR_ERROR, "Encryption error: %s",
297                   tsi_result_to_string(result));
298           break;
299         }
300         message_bytes += processed_message_size;
301         message_size -= processed_message_size;
302         cur += protected_buffer_size_to_send;
303 
304         if (cur == end) {
305           flush_write_staging_buffer(ep, &cur, &end);
306         }
307       }
308       if (result != TSI_OK) break;
309     }
310     if (result == TSI_OK) {
311       size_t still_pending_size;
312       do {
313         size_t protected_buffer_size_to_send = static_cast<size_t>(end - cur);
314         gpr_mu_lock(&ep->protector_mu);
315         result = tsi_frame_protector_protect_flush(
316             ep->protector, cur, &protected_buffer_size_to_send,
317             &still_pending_size);
318         gpr_mu_unlock(&ep->protector_mu);
319         if (result != TSI_OK) break;
320         cur += protected_buffer_size_to_send;
321         if (cur == end) {
322           flush_write_staging_buffer(ep, &cur, &end);
323         }
324       } while (still_pending_size > 0);
325       if (cur != GRPC_SLICE_START_PTR(ep->write_staging_buffer)) {
326         grpc_slice_buffer_add(
327             &ep->output_buffer,
328             grpc_slice_split_head(
329                 &ep->write_staging_buffer,
330                 static_cast<size_t>(
331                     cur - GRPC_SLICE_START_PTR(ep->write_staging_buffer))));
332       }
333     }
334   }
335 
336   if (result != TSI_OK) {
337     /* TODO(yangg) do different things according to the error type? */
338     grpc_slice_buffer_reset_and_unref_internal(&ep->output_buffer);
339     GRPC_CLOSURE_SCHED(
340         cb, grpc_set_tsi_error_result(
341                 GRPC_ERROR_CREATE_FROM_STATIC_STRING("Wrap failed"), result));
342     return;
343   }
344 
345   grpc_endpoint_write(ep->wrapped_ep, &ep->output_buffer, cb, arg);
346 }
347 
endpoint_shutdown(grpc_endpoint * secure_ep,grpc_error * why)348 static void endpoint_shutdown(grpc_endpoint* secure_ep, grpc_error* why) {
349   secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
350   grpc_endpoint_shutdown(ep->wrapped_ep, why);
351 }
352 
endpoint_destroy(grpc_endpoint * secure_ep)353 static void endpoint_destroy(grpc_endpoint* secure_ep) {
354   secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
355   SECURE_ENDPOINT_UNREF(ep, "destroy");
356 }
357 
endpoint_add_to_pollset(grpc_endpoint * secure_ep,grpc_pollset * pollset)358 static void endpoint_add_to_pollset(grpc_endpoint* secure_ep,
359                                     grpc_pollset* pollset) {
360   secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
361   grpc_endpoint_add_to_pollset(ep->wrapped_ep, pollset);
362 }
363 
endpoint_add_to_pollset_set(grpc_endpoint * secure_ep,grpc_pollset_set * pollset_set)364 static void endpoint_add_to_pollset_set(grpc_endpoint* secure_ep,
365                                         grpc_pollset_set* pollset_set) {
366   secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
367   grpc_endpoint_add_to_pollset_set(ep->wrapped_ep, pollset_set);
368 }
369 
endpoint_delete_from_pollset_set(grpc_endpoint * secure_ep,grpc_pollset_set * pollset_set)370 static void endpoint_delete_from_pollset_set(grpc_endpoint* secure_ep,
371                                              grpc_pollset_set* pollset_set) {
372   secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
373   grpc_endpoint_delete_from_pollset_set(ep->wrapped_ep, pollset_set);
374 }
375 
endpoint_get_peer(grpc_endpoint * secure_ep)376 static char* endpoint_get_peer(grpc_endpoint* secure_ep) {
377   secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
378   return grpc_endpoint_get_peer(ep->wrapped_ep);
379 }
380 
endpoint_get_fd(grpc_endpoint * secure_ep)381 static int endpoint_get_fd(grpc_endpoint* secure_ep) {
382   secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
383   return grpc_endpoint_get_fd(ep->wrapped_ep);
384 }
385 
endpoint_get_resource_user(grpc_endpoint * secure_ep)386 static grpc_resource_user* endpoint_get_resource_user(
387     grpc_endpoint* secure_ep) {
388   secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep);
389   return grpc_endpoint_get_resource_user(ep->wrapped_ep);
390 }
391 
/* Method table installed on every secure endpoint. The initializer is
   positional, so the entry order must match the member order of
   grpc_endpoint_vtable, which is declared elsewhere. Apart from read, write
   and destroy, every entry simply forwards to the wrapped endpoint. */
static const grpc_endpoint_vtable vtable = {endpoint_read,
                                            endpoint_write,
                                            endpoint_add_to_pollset,
                                            endpoint_add_to_pollset_set,
                                            endpoint_delete_from_pollset_set,
                                            endpoint_shutdown,
                                            endpoint_destroy,
                                            endpoint_get_resource_user,
                                            endpoint_get_peer,
                                            endpoint_get_fd};
402 
grpc_secure_endpoint_create(struct tsi_frame_protector * protector,struct tsi_zero_copy_grpc_protector * zero_copy_protector,grpc_endpoint * transport,grpc_slice * leftover_slices,size_t leftover_nslices)403 grpc_endpoint* grpc_secure_endpoint_create(
404     struct tsi_frame_protector* protector,
405     struct tsi_zero_copy_grpc_protector* zero_copy_protector,
406     grpc_endpoint* transport, grpc_slice* leftover_slices,
407     size_t leftover_nslices) {
408   size_t i;
409   secure_endpoint* ep =
410       static_cast<secure_endpoint*>(gpr_malloc(sizeof(secure_endpoint)));
411   ep->base.vtable = &vtable;
412   ep->wrapped_ep = transport;
413   ep->protector = protector;
414   ep->zero_copy_protector = zero_copy_protector;
415   grpc_slice_buffer_init(&ep->leftover_bytes);
416   for (i = 0; i < leftover_nslices; i++) {
417     grpc_slice_buffer_add(&ep->leftover_bytes,
418                           grpc_slice_ref_internal(leftover_slices[i]));
419   }
420   ep->write_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE);
421   ep->read_staging_buffer = GRPC_SLICE_MALLOC(STAGING_BUFFER_SIZE);
422   grpc_slice_buffer_init(&ep->output_buffer);
423   grpc_slice_buffer_init(&ep->source_buffer);
424   ep->read_buffer = nullptr;
425   GRPC_CLOSURE_INIT(&ep->on_read, on_read, ep, grpc_schedule_on_exec_ctx);
426   gpr_mu_init(&ep->protector_mu);
427   gpr_ref_init(&ep->ref, 1);
428   return &ep->base;
429 }
430