# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests of the beta-API code generated by the gRPC Python protoc plugin."""

import contextlib
import importlib
import os
from os import path
import pkgutil
import shutil
import sys
import tempfile
import threading
import unittest

from six import moves

from grpc.beta import implementations
from grpc.beta import interfaces
from grpc.framework.foundation import future
from grpc.framework.interfaces.face import face
from grpc_tools import protoc
from tests.unit.framework.common import test_constants

_RELATIVE_PROTO_PATH = 'relative_proto_path'
_RELATIVE_PYTHON_OUT = 'relative_python_out'

# Path components (directories, then file name) of each proto file as it is
# laid out under the temporary proto path created in setUp.
_PROTO_FILES_PATH_COMPONENTS = (
    (
        'beta_grpc_plugin_test',
        'payload',
        'test_payload.proto',
    ),
    (
        'beta_grpc_plugin_test',
        'requests',
        'r',
        'test_requests.proto',
    ),
    (
        'beta_grpc_plugin_test',
        'responses',
        'test_responses.proto',
    ),
    (
        'beta_grpc_plugin_test',
        'service',
        'test_service.proto',
    ),
)

# Dotted names of the modules imported after code generation.
_PAYLOAD_PB2 = 'beta_grpc_plugin_test.payload.test_payload_pb2'
_REQUESTS_PB2 = 'beta_grpc_plugin_test.requests.r.test_requests_pb2'
_RESPONSES_PB2 = 'beta_grpc_plugin_test.responses.test_responses_pb2'
_SERVICE_PB2 = 'beta_grpc_plugin_test.service.test_service_pb2'

# Identifiers of entities we expect to find in the generated module.
SERVICER_IDENTIFIER = 'BetaTestServiceServicer'
STUB_IDENTIFIER = 'BetaTestServiceStub'
SERVER_FACTORY_IDENTIFIER = 'beta_create_TestService_server'
STUB_FACTORY_IDENTIFIER = 'beta_create_TestService_stub'


@contextlib.contextmanager
def _system_path(path_insertion):
    """Temporarily inserts entries into sys.path after its first element.

    Args:
      path_insertion: A list of path strings to splice into sys.path.

    The previous sys.path is restored on exit, including when the body of
    the with statement raises.
    """
    old_system_path = sys.path[:]
    sys.path = sys.path[0:1] + path_insertion + sys.path[1:]
    try:
        yield
    finally:
        sys.path = old_system_path


def _create_directory_tree(root, path_components_sequence):
    """Creates, under root, every directory named by the given components.

    Args:
      root: The directory under which the tree is created.
      path_components_sequence: An iterable of sequences of path components;
        each sequence names one (possibly nested) directory to create.
    """
    created = set()
    for path_components in path_components_sequence:
        thus_far = ''
        for path_component in path_components:
            relative_path = path.join(thus_far, path_component)
            # Several sequences may share a prefix; create each level once.
            if relative_path not in created:
                os.makedirs(path.join(root, relative_path))
                created.add(relative_path)
            thus_far = relative_path


def _massage_proto_content(raw_proto_content):
    """Rewrites proto import paths and the package name for this test.

    Args:
      raw_proto_content: The bytes of a proto file.

    Returns:
      The bytes with 'tests/protoc_plugin/protos/' imports redirected to
        'beta_grpc_plugin_test/' and the 'grpc_protoc_plugin' package
        statement renamed to 'beta_grpc_protoc_plugin'.
    """
    imports_substituted = raw_proto_content.replace(
        b'import "tests/protoc_plugin/protos/',
        b'import "beta_grpc_plugin_test/')
    package_statement_substituted = imports_substituted.replace(
        b'package grpc_protoc_plugin;', b'package beta_grpc_protoc_plugin;')
    return package_statement_substituted


def _packagify(directory):
    """Writes an empty __init__.py into directory and all its subdirectories."""
    for subdirectory, _, _ in os.walk(directory):
        init_file_name = path.join(subdirectory, '__init__.py')
        with open(init_file_name, 'wb') as init_file:
            init_file.write(b'')


class _ServicerMethods(object):
    """Implementations of the test service's methods with pause/fail hooks."""

    def __init__(self, payload_pb2, responses_pb2):
        self._condition = threading.Condition()
        self._paused = False
        self._fail = False
        self._payload_pb2 = payload_pb2
        self._responses_pb2 = responses_pb2

    @contextlib.contextmanager
    def pause(self):  # pylint: disable=invalid-name
        """Makes method invocations block in _control for the with body."""
        with self._condition:
            self._paused = True
        try:
            yield
        finally:
            # Always release waiters, even if the body raised; otherwise
            # blocked handlers would never be woken.
            with self._condition:
                self._paused = False
                self._condition.notify_all()

    @contextlib.contextmanager
    def fail(self):  # pylint: disable=invalid-name
        """Makes method invocations raise in _control for the with body."""
        with self._condition:
            self._fail = True
        try:
            yield
        finally:
            with self._condition:
                self._fail = False

    def _control(self):  # pylint: disable=invalid-name
        """Raises ValueError if failing; otherwise blocks while paused."""
        with self._condition:
            if self._fail:
                raise ValueError()
            while self._paused:
                self._condition.wait()

    def UnaryCall(self, request, unused_rpc_context):
        """Returns a response whose payload has request.response_size 'a's."""
        response = self._responses_pb2.SimpleResponse()
        response.payload.payload_type = self._payload_pb2.COMPRESSABLE
        response.payload.payload_compressable = 'a' * request.response_size
        self._control()
        return response

    def StreamingOutputCall(self, request, unused_rpc_context):
        """Yields one response per entry of request.response_parameters."""
        for parameter in request.response_parameters:
            response = self._responses_pb2.StreamingOutputCallResponse()
            response.payload.payload_type = self._payload_pb2.COMPRESSABLE
            response.payload.payload_compressable = 'a' * parameter.size
            self._control()
            yield response

    def StreamingInputCall(self, request_iter, unused_rpc_context):
        """Returns a response carrying the summed request payload lengths."""
        response = self._responses_pb2.StreamingInputCallResponse()
        aggregated_payload_size = 0
        for request in request_iter:
            aggregated_payload_size += len(request.payload.payload_compressable)
        response.aggregated_payload_size = aggregated_payload_size
        self._control()
        return response

    def FullDuplexCall(self, request_iter, unused_rpc_context):
        """Yields responses for each request as that request is consumed."""
        for request in request_iter:
            for parameter in request.response_parameters:
                response = self._responses_pb2.StreamingOutputCallResponse()
                response.payload.payload_type = self._payload_pb2.COMPRESSABLE
                response.payload.payload_compressable = 'a' * parameter.size
                self._control()
                yield response

    def HalfDuplexCall(self, request_iter, unused_rpc_context):
        """Consumes every request, then yields all accumulated responses."""
        responses = []
        for request in request_iter:
            for parameter in request.response_parameters:
                response = self._responses_pb2.StreamingOutputCallResponse()
                response.payload.payload_type = self._payload_pb2.COMPRESSABLE
                response.payload.payload_compressable = 'a' * parameter.size
                self._control()
                responses.append(response)
        for response in responses:
            yield response


@contextlib.contextmanager
def _CreateService(payload_pb2, responses_pb2, service_pb2):
    """Provides a servicer backend and a stub.

    The servicer registered with the server only forwards each call to the
    yielded servicer_methods object, so tests can pause or fail the backend
    through that object.

    Args:
      payload_pb2: The payload_pb2 module generated by this test.
      responses_pb2: The responses_pb2 module generated by this test.
      service_pb2: The service_pb2 module generated by this test.

    Yields:
      A (servicer_methods, stub) pair where servicer_methods is the back-end
        of the service bound to the stub and stub is the stub on which to
        invoke RPCs.
    """
    servicer_methods = _ServicerMethods(payload_pb2, responses_pb2)

    class Servicer(getattr(service_pb2, SERVICER_IDENTIFIER)):

        def UnaryCall(self, request, context):
            return servicer_methods.UnaryCall(request, context)

        def StreamingOutputCall(self, request, context):
            return servicer_methods.StreamingOutputCall(request, context)

        def StreamingInputCall(self, request_iter, context):
            return servicer_methods.StreamingInputCall(request_iter, context)

        def FullDuplexCall(self, request_iter, context):
            return servicer_methods.FullDuplexCall(request_iter, context)

        def HalfDuplexCall(self, request_iter, context):
            return servicer_methods.HalfDuplexCall(request_iter, context)

    servicer = Servicer()
    server = getattr(service_pb2, SERVER_FACTORY_IDENTIFIER)(servicer)
    port = server.add_insecure_port('[::]:0')
    server.start()
    channel = implementations.insecure_channel('localhost', port)
    stub = getattr(service_pb2, STUB_FACTORY_IDENTIFIER)(channel)
    try:
        yield (servicer_methods, stub)
    finally:
        # Stop the server even when the test body raised.
        server.stop(0)


@contextlib.contextmanager
def _CreateIncompleteService(service_pb2):
    """Provides a stub bound to a servicer that implements no methods.

    The servicer subclasses the generated servicer class without overriding
    anything.

    Args:
      service_pb2: The service_pb2 module generated by this test.

    Yields:
      A (None, stub) pair where stub is the stub on which to invoke RPCs;
        the first element is None because there is no servicer back-end.
    """

    class Servicer(getattr(service_pb2, SERVICER_IDENTIFIER)):
        pass

    servicer = Servicer()
    server = getattr(service_pb2, SERVER_FACTORY_IDENTIFIER)(servicer)
    port = server.add_insecure_port('[::]:0')
    server.start()
    channel = implementations.insecure_channel('localhost', port)
    stub = getattr(service_pb2, STUB_FACTORY_IDENTIFIER)(channel)
    try:
        yield None, stub
    finally:
        # Stop the server even when the test body raised.
        server.stop(0)


def _streaming_input_request_iterator(payload_pb2, requests_pb2):
    """Yields three StreamingInputCallRequests, each with a one-byte payload."""
    for _ in range(3):
        request = requests_pb2.StreamingInputCallRequest()
        request.payload.payload_type = payload_pb2.COMPRESSABLE
        request.payload.payload_compressable = 'a'
        yield request


def _streaming_output_request(requests_pb2):
    """Returns a StreamingOutputCallRequest asking for sizes 1, 2 and 3."""
    request = requests_pb2.StreamingOutputCallRequest()
    for size in (1, 2, 3):
        request.response_parameters.add(size=size, interval_us=0)
    return request


def _full_duplex_request_iterator(requests_pb2):
    """Yields two requests: one asking for size 1, one for sizes 2 and 3."""
    request = requests_pb2.StreamingOutputCallRequest()
    request.response_parameters.add(size=1, interval_us=0)
    yield request
    request = requests_pb2.StreamingOutputCallRequest()
    request.response_parameters.add(size=2, interval_us=0)
    request.response_parameters.add(size=3, interval_us=0)
    yield request


class PythonPluginTest(unittest.TestCase):
    """Test case for the gRPC Python protoc-plugin.

    While reading these tests, remember that the futures API
    (`stub.method.future()`) only gives futures for the *response-unary*
    methods and does not exist for response-streaming methods.
    """

    def setUp(self):
        # Lay out the massaged proto files in a fresh temporary directory.
        self._directory = tempfile.mkdtemp(dir='.')
        self._proto_path = path.join(self._directory, _RELATIVE_PROTO_PATH)
        self._python_out = path.join(self._directory, _RELATIVE_PYTHON_OUT)

        os.makedirs(self._proto_path)
        os.makedirs(self._python_out)

        directories_path_components = {
            proto_file_path_components[:-1]
            for proto_file_path_components in _PROTO_FILES_PATH_COMPONENTS
        }
        _create_directory_tree(self._proto_path, directories_path_components)
        self._proto_file_names = set()
        for proto_file_path_components in _PROTO_FILES_PATH_COMPONENTS:
            # The leading component is this test's own namespace directory;
            # the package data is looked up without it.
            raw_proto_content = pkgutil.get_data(
                'tests.protoc_plugin.protos',
                path.join(*proto_file_path_components[1:]))
            massaged_proto_content = _massage_proto_content(raw_proto_content)
            proto_file_name = path.join(self._proto_path,
                                        *proto_file_path_components)
            with open(proto_file_name, 'wb') as proto_file:
                proto_file.write(massaged_proto_content)
            self._proto_file_names.add(proto_file_name)

    def tearDown(self):
        shutil.rmtree(self._directory)

    def _protoc(self):
        """Generates code from the proto files and imports the modules."""
        # NOTE(review): the leading '' presumably stands in for the program
        # name in the argv-style list -- confirm against grpc_tools.protoc.
        args = [
            '',
            '--proto_path={}'.format(self._proto_path),
            '--python_out={}'.format(self._python_out),
            '--grpc_python_out=grpc_1_0:{}'.format(self._python_out),
        ] + list(self._proto_file_names)
        protoc_exit_code = protoc.main(args)
        self.assertEqual(0, protoc_exit_code)

        _packagify(self._python_out)

        with _system_path([self._python_out]):
            self._payload_pb2 = importlib.import_module(_PAYLOAD_PB2)
            self._requests_pb2 = importlib.import_module(_REQUESTS_PB2)
            self._responses_pb2 = importlib.import_module(_RESPONSES_PB2)
            self._service_pb2 = importlib.import_module(_SERVICE_PB2)

    def testImportAttributes(self):
        self._protoc()

        # Check that we can access the generated module and its members.
        self.assertIsNotNone(
            getattr(self._service_pb2, SERVICER_IDENTIFIER, None))
        self.assertIsNotNone(getattr(self._service_pb2, STUB_IDENTIFIER, None))
        self.assertIsNotNone(
            getattr(self._service_pb2, SERVER_FACTORY_IDENTIFIER, None))
        self.assertIsNotNone(
            getattr(self._service_pb2, STUB_FACTORY_IDENTIFIER, None))

    def testUpDown(self):
        self._protoc()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2):
            self._requests_pb2.SimpleRequest(response_size=13)

    def testIncompleteServicer(self):
        self._protoc()

        with _CreateIncompleteService(self._service_pb2) as (_, stub):
            request = self._requests_pb2.SimpleRequest(response_size=13)
            # assertRaises makes the test fail if the call unexpectedly
            # succeeds instead of passing without checking anything.
            with self.assertRaises(face.AbortionError) as exception_context:
                stub.UnaryCall(request, test_constants.LONG_TIMEOUT)
            self.assertEqual(interfaces.StatusCode.UNIMPLEMENTED,
                             exception_context.exception.code)

    def testUnaryCall(self):
        self._protoc()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (methods, stub):
            request = self._requests_pb2.SimpleRequest(response_size=13)
            response = stub.UnaryCall(request, test_constants.LONG_TIMEOUT)
        expected_response = methods.UnaryCall(request, 'not a real context!')
        self.assertEqual(expected_response, response)

    def testUnaryCallFuture(self):
        self._protoc()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (methods, stub):
            request = self._requests_pb2.SimpleRequest(response_size=13)
            # Check that the call does not block waiting for the server to
            # respond.
            with methods.pause():
                response_future = stub.UnaryCall.future(
                    request, test_constants.LONG_TIMEOUT)
            response = response_future.result()
        expected_response = methods.UnaryCall(request, 'not a real RpcContext!')
        self.assertEqual(expected_response, response)

    def testUnaryCallFutureExpired(self):
        self._protoc()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (methods, stub):
            request = self._requests_pb2.SimpleRequest(response_size=13)
            with methods.pause():
                response_future = stub.UnaryCall.future(
                    request, test_constants.SHORT_TIMEOUT)
                with self.assertRaises(face.ExpirationError):
                    response_future.result()

    def testUnaryCallFutureCancelled(self):
        self._protoc()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (methods, stub):
            request = self._requests_pb2.SimpleRequest(response_size=13)
            with methods.pause():
                response_future = stub.UnaryCall.future(request, 1)
                response_future.cancel()
                self.assertTrue(response_future.cancelled())

    def testUnaryCallFutureFailed(self):
        self._protoc()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (methods, stub):
            request = self._requests_pb2.SimpleRequest(response_size=13)
            with methods.fail():
                response_future = stub.UnaryCall.future(
                    request, test_constants.LONG_TIMEOUT)
                self.assertIsNotNone(response_future.exception())

    def testStreamingOutputCall(self):
        self._protoc()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (methods, stub):
            request = _streaming_output_request(self._requests_pb2)
            responses = stub.StreamingOutputCall(request,
                                                 test_constants.LONG_TIMEOUT)
            expected_responses = methods.StreamingOutputCall(
                request, 'not a real RpcContext!')
            # zip_longest pads the shorter side with None, so a length
            # mismatch shows up as a failed equality.
            for expected_response, response in moves.zip_longest(
                    expected_responses, responses):
                self.assertEqual(expected_response, response)

    def testStreamingOutputCallExpired(self):
        self._protoc()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (methods, stub):
            request = _streaming_output_request(self._requests_pb2)
            with methods.pause():
                responses = stub.StreamingOutputCall(
                    request, test_constants.SHORT_TIMEOUT)
                with self.assertRaises(face.ExpirationError):
                    list(responses)

    def testStreamingOutputCallCancelled(self):
        self._protoc()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (_, stub):
            request = _streaming_output_request(self._requests_pb2)
            responses = stub.StreamingOutputCall(request,
                                                 test_constants.LONG_TIMEOUT)
            next(responses)
            responses.cancel()
            with self.assertRaises(face.CancellationError):
                next(responses)

    def testStreamingOutputCallFailed(self):
        self._protoc()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (methods, stub):
            request = _streaming_output_request(self._requests_pb2)
            with methods.fail():
                responses = stub.StreamingOutputCall(request, 1)
                self.assertIsNotNone(responses)
                with self.assertRaises(face.RemoteError):
                    next(responses)

    def testStreamingInputCall(self):
        self._protoc()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (methods, stub):
            response = stub.StreamingInputCall(
                _streaming_input_request_iterator(self._payload_pb2,
                                                  self._requests_pb2),
                test_constants.LONG_TIMEOUT)
        expected_response = methods.StreamingInputCall(
            _streaming_input_request_iterator(self._payload_pb2,
                                              self._requests_pb2),
            'not a real RpcContext!')
        self.assertEqual(expected_response, response)

    def testStreamingInputCallFuture(self):
        self._protoc()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (methods, stub):
            with methods.pause():
                response_future = stub.StreamingInputCall.future(
                    _streaming_input_request_iterator(self._payload_pb2,
                                                      self._requests_pb2),
                    test_constants.LONG_TIMEOUT)
            response = response_future.result()
        expected_response = methods.StreamingInputCall(
            _streaming_input_request_iterator(self._payload_pb2,
                                              self._requests_pb2),
            'not a real RpcContext!')
        self.assertEqual(expected_response, response)

    def testStreamingInputCallFutureExpired(self):
        self._protoc()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (methods, stub):
            with methods.pause():
                response_future = stub.StreamingInputCall.future(
                    _streaming_input_request_iterator(self._payload_pb2,
                                                      self._requests_pb2),
                    test_constants.SHORT_TIMEOUT)
                with self.assertRaises(face.ExpirationError):
                    response_future.result()
                self.assertIsInstance(response_future.exception(),
                                      face.ExpirationError)

    def testStreamingInputCallFutureCancelled(self):
        self._protoc()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (methods, stub):
            with methods.pause():
                response_future = stub.StreamingInputCall.future(
                    _streaming_input_request_iterator(self._payload_pb2,
                                                      self._requests_pb2),
                    test_constants.LONG_TIMEOUT)
                response_future.cancel()
                self.assertTrue(response_future.cancelled())
            with self.assertRaises(future.CancelledError):
                response_future.result()

    def testStreamingInputCallFutureFailed(self):
        self._protoc()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (methods, stub):
            with methods.fail():
                response_future = stub.StreamingInputCall.future(
                    _streaming_input_request_iterator(self._payload_pb2,
                                                      self._requests_pb2),
                    test_constants.LONG_TIMEOUT)
                self.assertIsNotNone(response_future.exception())

    def testFullDuplexCall(self):
        self._protoc()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (methods, stub):
            responses = stub.FullDuplexCall(
                _full_duplex_request_iterator(self._requests_pb2),
                test_constants.LONG_TIMEOUT)
            expected_responses = methods.FullDuplexCall(
                _full_duplex_request_iterator(self._requests_pb2),
                'not a real RpcContext!')
            for expected_response, response in moves.zip_longest(
                    expected_responses, responses):
                self.assertEqual(expected_response, response)

    def testFullDuplexCallExpired(self):
        self._protoc()

        request_iterator = _full_duplex_request_iterator(self._requests_pb2)
        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (methods, stub):
            with methods.pause():
                responses = stub.FullDuplexCall(request_iterator,
                                                test_constants.SHORT_TIMEOUT)
                with self.assertRaises(face.ExpirationError):
                    list(responses)

    def testFullDuplexCallCancelled(self):
        self._protoc()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (_, stub):
            request_iterator = _full_duplex_request_iterator(self._requests_pb2)
            responses = stub.FullDuplexCall(request_iterator,
                                            test_constants.LONG_TIMEOUT)
            next(responses)
            responses.cancel()
            with self.assertRaises(face.CancellationError):
                next(responses)

    def testFullDuplexCallFailed(self):
        self._protoc()

        request_iterator = _full_duplex_request_iterator(self._requests_pb2)
        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (methods, stub):
            with methods.fail():
                responses = stub.FullDuplexCall(request_iterator,
                                                test_constants.LONG_TIMEOUT)
                self.assertIsNotNone(responses)
                with self.assertRaises(face.RemoteError):
                    next(responses)

    def testHalfDuplexCall(self):
        self._protoc()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (methods, stub):

            def half_duplex_request_iterator():
                request = self._requests_pb2.StreamingOutputCallRequest()
                request.response_parameters.add(size=1, interval_us=0)
                yield request
                request = self._requests_pb2.StreamingOutputCallRequest()
                request.response_parameters.add(size=2, interval_us=0)
                request.response_parameters.add(size=3, interval_us=0)
                yield request

            responses = stub.HalfDuplexCall(half_duplex_request_iterator(),
                                            test_constants.LONG_TIMEOUT)
            expected_responses = methods.HalfDuplexCall(
                half_duplex_request_iterator(), 'not a real RpcContext!')
            for expected_response, response in moves.zip_longest(
                    expected_responses, responses):
                self.assertEqual(expected_response, response)

    def testHalfDuplexCallWedged(self):
        self._protoc()

        condition = threading.Condition()
        # A one-element list stands in for a rebindable flag because the
        # 'nonlocal' statement is not used here.
        wait_cell = [False]

        @contextlib.contextmanager
        def wait():  # pylint: disable=invalid-name
            with condition:
                wait_cell[0] = True
            try:
                yield
            finally:
                # Release the request iterator even if the body raised, so
                # the thread consuming it is not left blocked.
                with condition:
                    wait_cell[0] = False
                    condition.notify_all()

        def half_duplex_request_iterator():
            request = self._requests_pb2.StreamingOutputCallRequest()
            request.response_parameters.add(size=1, interval_us=0)
            yield request
            # Hold the request stream open until wait() exits.
            with condition:
                while wait_cell[0]:
                    condition.wait()

        with _CreateService(self._payload_pb2, self._responses_pb2,
                            self._service_pb2) as (_, stub):
            with wait():
                responses = stub.HalfDuplexCall(half_duplex_request_iterator(),
                                                test_constants.SHORT_TIMEOUT)
                # half-duplex waits for the client to send all info
                with self.assertRaises(face.ExpirationError):
                    next(responses)


if __name__ == '__main__':
    unittest.main(verbosity=2)