1# Copyright 2015 gRPC authors. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15import collections 16import contextlib 17import distutils.spawn 18import errno 19import os 20import shutil 21import subprocess 22import sys 23import tempfile 24import threading 25import unittest 26 27from six import moves 28 29import grpc 30from tests.unit import test_common 31from tests.unit.framework.common import test_constants 32 33import tests.protoc_plugin.protos.payload.test_payload_pb2 as payload_pb2 34import tests.protoc_plugin.protos.requests.r.test_requests_pb2 as request_pb2 35import tests.protoc_plugin.protos.responses.test_responses_pb2 as response_pb2 36import tests.protoc_plugin.protos.service.test_service_pb2_grpc as service_pb2_grpc 37 38# Identifiers of entities we expect to find in the generated module. 
STUB_IDENTIFIER = 'TestServiceStub'
SERVICER_IDENTIFIER = 'TestServiceServicer'
ADD_SERVICER_TO_SERVER_IDENTIFIER = 'add_TestServiceServicer_to_server'


class _ServicerMethods(object):
    """Reference implementations of the TestService RPC methods.

    Every method calls `_control()` while producing its response(s), which
    lets a test make the servicer block (`pause()`) or raise (`fail()`) on
    demand. The tests also call these methods directly to compute the
    responses they expect from the live service.
    """

    def __init__(self):
        # Guards _paused and _fail; threads blocked in _control() are woken
        # via notify_all() when a pause() block ends.
        self._condition = threading.Condition()
        self._paused = False
        self._fail = False

    @contextlib.contextmanager
    def pause(self):  # pylint: disable=invalid-name
        """Makes `_control()` block for the duration of the `with` body.

        The flag is cleared, and blocked callers are woken, in a `finally`
        clause. An exception raised inside the `with` body (for example a
        failed assertion) is re-raised at the `yield`. Without the `finally`
        it would leave servicer threads waiting forever and hang the test
        process instead of reporting the failure.
        """
        with self._condition:
            self._paused = True
        try:
            yield
        finally:
            with self._condition:
                self._paused = False
                self._condition.notify_all()

    @contextlib.contextmanager
    def fail(self):  # pylint: disable=invalid-name
        """Makes `_control()` raise ValueError for the duration of the body.

        The flag is cleared in a `finally` clause, so an exception raised
        inside the `with` body cannot leave the servicer permanently failing.
        """
        with self._condition:
            self._fail = True
        try:
            yield
        finally:
            with self._condition:
                self._fail = False

    def _control(self):  # pylint: disable=invalid-name
        """Raises if `fail()` is active; otherwise blocks while paused.

        Raises:
          ValueError: if called while a `fail()` block is active.
        """
        with self._condition:
            if self._fail:
                raise ValueError()
            while self._paused:
                self._condition.wait()

    def UnaryCall(self, request, unused_rpc_context):
        """Returns a COMPRESSABLE payload of `request.response_size` 'a's."""
        response = response_pb2.SimpleResponse()
        response.payload.payload_type = payload_pb2.COMPRESSABLE
        response.payload.payload_compressable = 'a' * request.response_size
        self._control()
        return response

    def StreamingOutputCall(self, request, unused_rpc_context):
        """Yields one response per entry of `request.response_parameters`.

        Each response carries a COMPRESSABLE payload of `parameter.size` 'a's.
        """
        for parameter in request.response_parameters:
            response = response_pb2.StreamingOutputCallResponse()
            response.payload.payload_type = payload_pb2.COMPRESSABLE
            response.payload.payload_compressable = 'a' * parameter.size
            self._control()
            yield response

    def StreamingInputCall(self, request_iter, unused_rpc_context):
        """Returns the summed payload_compressable length of all requests."""
        response = response_pb2.StreamingInputCallResponse()
        aggregated_payload_size = 0
        for request in request_iter:
            aggregated_payload_size += len(request.payload.payload_compressable)
        response.aggregated_payload_size = aggregated_payload_size
        self._control()
        return response

    def FullDuplexCall(self, request_iter, unused_rpc_context):
        """Yields the responses for each request before reading the next one."""
        for request in request_iter:
            for parameter in request.response_parameters:
                response = response_pb2.StreamingOutputCallResponse()
                response.payload.payload_type = payload_pb2.COMPRESSABLE
                response.payload.payload_compressable = 'a' * parameter.size
                self._control()
                yield response

    def HalfDuplexCall(self, request_iter, unused_rpc_context):
        """Consumes every request first, then yields all of the responses."""
        responses = []
        for request in request_iter:
            for parameter in request.response_parameters:
                response = response_pb2.StreamingOutputCallResponse()
                response.payload.payload_type = payload_pb2.COMPRESSABLE
                response.payload.payload_compressable = 'a' * parameter.size
                self._control()
                responses.append(response)
        for response in responses:
            yield response


class _Service(
        collections.namedtuple('_Service', (
            'servicer_methods',
            'server',
            'stub',
        ))):
    """A live and running service.

    Attributes:
      servicer_methods: The _ServicerMethods servicing RPCs.
      server: The grpc.Server servicing RPCs.
      stub: A stub on which to invoke RPCs.
    """


def _CreateService():
    """Provides a servicer backend and a stub.

    Returns:
      A _Service with which to test RPCs.
    """
    servicer_methods = _ServicerMethods()

    class Servicer(getattr(service_pb2_grpc, SERVICER_IDENTIFIER)):

        def UnaryCall(self, request, context):
            return servicer_methods.UnaryCall(request, context)

        def StreamingOutputCall(self, request, context):
            return servicer_methods.StreamingOutputCall(request, context)

        def StreamingInputCall(self, request_iter, context):
            return servicer_methods.StreamingInputCall(request_iter, context)

        def FullDuplexCall(self, request_iter, context):
            return servicer_methods.FullDuplexCall(request_iter, context)

        def HalfDuplexCall(self, request_iter, context):
            return servicer_methods.HalfDuplexCall(request_iter, context)

    server = test_common.test_server()
    getattr(service_pb2_grpc, ADD_SERVICER_TO_SERVER_IDENTIFIER)(Servicer(),
                                                                 server)
    # Port 0 lets the server pick a free port; the bound port is returned.
    port = server.add_insecure_port('[::]:0')
    server.start()
    channel = grpc.insecure_channel('localhost:{}'.format(port))
    stub = getattr(service_pb2_grpc, STUB_IDENTIFIER)(channel)
    return _Service(servicer_methods, server, stub)


def _CreateIncompleteService():
    """Provides a servicer backend that fails to implement methods and its stub.

    Returns:
      A _Service with which to test RPCs. The returned _Service's
      servicer_methods implements none of the methods required of it.
    """

    class Servicer(getattr(service_pb2_grpc, SERVICER_IDENTIFIER)):
        pass

    server = test_common.test_server()
    getattr(service_pb2_grpc, ADD_SERVICER_TO_SERVER_IDENTIFIER)(Servicer(),
                                                                 server)
    port = server.add_insecure_port('[::]:0')
    server.start()
    channel = grpc.insecure_channel('localhost:{}'.format(port))
    stub = getattr(service_pb2_grpc, STUB_IDENTIFIER)(channel)
    return _Service(None, server, stub)


def _streaming_input_request_iterator():
    """Yields three StreamingInputCallRequests, each with a 1-byte payload."""
    for _ in range(3):
        request = request_pb2.StreamingInputCallRequest()
        request.payload.payload_type = payload_pb2.COMPRESSABLE
        request.payload.payload_compressable = 'a'
        yield request


def _streaming_output_request():
    """Returns a request asking for three responses of sizes 1, 2 and 3."""
    request = request_pb2.StreamingOutputCallRequest()
    sizes = [1, 2, 3]
    request.response_parameters.add(size=sizes[0], interval_us=0)
    request.response_parameters.add(size=sizes[1], interval_us=0)
    request.response_parameters.add(size=sizes[2], interval_us=0)
    return request


def _full_duplex_request_iterator():
    """Yields two requests: one asking for size 1, one for sizes 2 and 3."""
    request = request_pb2.StreamingOutputCallRequest()
    request.response_parameters.add(size=1, interval_us=0)
    yield request
    request = request_pb2.StreamingOutputCallRequest()
    request.response_parameters.add(size=2, interval_us=0)
    request.response_parameters.add(size=3, interval_us=0)
    yield request


class PythonPluginTest(unittest.TestCase):
    """Test case for the gRPC Python protoc-plugin.

    While reading these tests, remember that the futures API
    (`stub.method.future()`) only gives futures for the *response-unary*
    methods and does not exist for response-streaming methods.
    """

    def testImportAttributes(self):
        # check that we can access the generated module and its members.
        self.assertIsNotNone(getattr(service_pb2_grpc, STUB_IDENTIFIER, None))
        self.assertIsNotNone(
            getattr(service_pb2_grpc, SERVICER_IDENTIFIER, None))
        self.assertIsNotNone(
            getattr(service_pb2_grpc, ADD_SERVICER_TO_SERVER_IDENTIFIER, None))

    def testUpDown(self):
        service = _CreateService()
        self.assertIsNotNone(service.servicer_methods)
        self.assertIsNotNone(service.server)
        self.assertIsNotNone(service.stub)
        service.server.stop(None)

    def testIncompleteServicer(self):
        service = _CreateIncompleteService()
        request = request_pb2.SimpleRequest(response_size=13)
        with self.assertRaises(grpc.RpcError) as exception_context:
            service.stub.UnaryCall(request)
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.UNIMPLEMENTED)
        service.server.stop(None)

    def testUnaryCall(self):
        service = _CreateService()
        request = request_pb2.SimpleRequest(response_size=13)
        response = service.stub.UnaryCall(request)
        expected_response = service.servicer_methods.UnaryCall(
            request, 'not a real context!')
        self.assertEqual(expected_response, response)
        service.server.stop(None)

    def testUnaryCallFuture(self):
        service = _CreateService()
        request = request_pb2.SimpleRequest(response_size=13)
        # Check that the call does not block waiting for the server to respond.
        with service.servicer_methods.pause():
            response_future = service.stub.UnaryCall.future(request)
        # Only collect the result once the servicer has been unpaused.
        response = response_future.result()
        expected_response = service.servicer_methods.UnaryCall(
            request, 'not a real RpcContext!')
        self.assertEqual(expected_response, response)
        service.server.stop(None)

    def testUnaryCallFutureExpired(self):
        service = _CreateService()
        request = request_pb2.SimpleRequest(response_size=13)
        with service.servicer_methods.pause():
            response_future = service.stub.UnaryCall.future(
                request, timeout=test_constants.SHORT_TIMEOUT)
            # The servicer stays paused, so the deadline must elapse.
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_future.result()
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.DEADLINE_EXCEEDED)
        self.assertIs(response_future.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
        service.server.stop(None)

    def testUnaryCallFutureCancelled(self):
        service = _CreateService()
        request = request_pb2.SimpleRequest(response_size=13)
        with service.servicer_methods.pause():
            response_future = service.stub.UnaryCall.future(request)
            response_future.cancel()
        self.assertTrue(response_future.cancelled())
        self.assertIs(response_future.code(), grpc.StatusCode.CANCELLED)
        service.server.stop(None)

    def testUnaryCallFutureFailed(self):
        service = _CreateService()
        request = request_pb2.SimpleRequest(response_size=13)
        with service.servicer_methods.fail():
            response_future = service.stub.UnaryCall.future(request)
            # Wait for the outcome while the servicer is still failing.
            self.assertIsNotNone(response_future.exception())
        self.assertIs(response_future.code(), grpc.StatusCode.UNKNOWN)
        service.server.stop(None)

    def testStreamingOutputCall(self):
        service = _CreateService()
        request = _streaming_output_request()
        responses = service.stub.StreamingOutputCall(request)
        expected_responses = service.servicer_methods.StreamingOutputCall(
            request, 'not a real RpcContext!')
        # zip_longest pads with None, so a length mismatch fails the test.
        for expected_response, response in moves.zip_longest(
                expected_responses, responses):
            self.assertEqual(expected_response, response)
        service.server.stop(None)

    def testStreamingOutputCallExpired(self):
        service = _CreateService()
        request = _streaming_output_request()
        with service.servicer_methods.pause():
            responses = service.stub.StreamingOutputCall(
                request, timeout=test_constants.SHORT_TIMEOUT)
            with self.assertRaises(grpc.RpcError) as exception_context:
                list(responses)
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.DEADLINE_EXCEEDED)
        service.server.stop(None)

    def testStreamingOutputCallCancelled(self):
        service = _CreateService()
        request = _streaming_output_request()
        responses = service.stub.StreamingOutputCall(request)
        next(responses)
        responses.cancel()
        with self.assertRaises(grpc.RpcError):
            next(responses)
        self.assertIs(responses.code(), grpc.StatusCode.CANCELLED)
        service.server.stop(None)

    def testStreamingOutputCallFailed(self):
        service = _CreateService()
        request = _streaming_output_request()
        with service.servicer_methods.fail():
            responses = service.stub.StreamingOutputCall(request)
            self.assertIsNotNone(responses)
            with self.assertRaises(grpc.RpcError) as exception_context:
                next(responses)
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.UNKNOWN)
        service.server.stop(None)

    def testStreamingInputCall(self):
        service = _CreateService()
        response = service.stub.StreamingInputCall(
            _streaming_input_request_iterator())
        expected_response = service.servicer_methods.StreamingInputCall(
            _streaming_input_request_iterator(), 'not a real RpcContext!')
        self.assertEqual(expected_response, response)
        service.server.stop(None)

    def testStreamingInputCallFuture(self):
        service = _CreateService()
        with service.servicer_methods.pause():
            response_future = service.stub.StreamingInputCall.future(
                _streaming_input_request_iterator())
        response = response_future.result()
        expected_response = service.servicer_methods.StreamingInputCall(
            _streaming_input_request_iterator(), 'not a real RpcContext!')
        self.assertEqual(expected_response, response)
        service.server.stop(None)

    def testStreamingInputCallFutureExpired(self):
        service = _CreateService()
        with service.servicer_methods.pause():
            response_future = service.stub.StreamingInputCall.future(
                _streaming_input_request_iterator(),
                timeout=test_constants.SHORT_TIMEOUT)
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_future.result()
        self.assertIsInstance(response_future.exception(), grpc.RpcError)
        self.assertIs(response_future.exception().code(),
                      grpc.StatusCode.DEADLINE_EXCEEDED)
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.DEADLINE_EXCEEDED)
        service.server.stop(None)

    def testStreamingInputCallFutureCancelled(self):
        service = _CreateService()
        with service.servicer_methods.pause():
            response_future = service.stub.StreamingInputCall.future(
                _streaming_input_request_iterator())
            response_future.cancel()
        self.assertTrue(response_future.cancelled())
        with self.assertRaises(grpc.FutureCancelledError):
            response_future.result()
        service.server.stop(None)

    def testStreamingInputCallFutureFailed(self):
        service = _CreateService()
        with service.servicer_methods.fail():
            response_future = service.stub.StreamingInputCall.future(
                _streaming_input_request_iterator())
            self.assertIsNotNone(response_future.exception())
        self.assertIs(response_future.code(), grpc.StatusCode.UNKNOWN)
        service.server.stop(None)

    def testFullDuplexCall(self):
        service = _CreateService()
        responses = service.stub.FullDuplexCall(_full_duplex_request_iterator())
        expected_responses = service.servicer_methods.FullDuplexCall(
            _full_duplex_request_iterator(), 'not a real RpcContext!')
        for expected_response, response in moves.zip_longest(
                expected_responses, responses):
            self.assertEqual(expected_response, response)
        service.server.stop(None)

    def testFullDuplexCallExpired(self):
        request_iterator = _full_duplex_request_iterator()
        service = _CreateService()
        with service.servicer_methods.pause():
            responses = service.stub.FullDuplexCall(
                request_iterator, timeout=test_constants.SHORT_TIMEOUT)
            with self.assertRaises(grpc.RpcError) as exception_context:
                list(responses)
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.DEADLINE_EXCEEDED)
        service.server.stop(None)

    def testFullDuplexCallCancelled(self):
        service = _CreateService()
        request_iterator = _full_duplex_request_iterator()
        responses = service.stub.FullDuplexCall(request_iterator)
        next(responses)
        responses.cancel()
        with self.assertRaises(grpc.RpcError) as exception_context:
            next(responses)
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.CANCELLED)
        service.server.stop(None)

    def testFullDuplexCallFailed(self):
        request_iterator = _full_duplex_request_iterator()
        service = _CreateService()
        with service.servicer_methods.fail():
            responses = service.stub.FullDuplexCall(request_iterator)
            with self.assertRaises(grpc.RpcError) as exception_context:
                next(responses)
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.UNKNOWN)
        service.server.stop(None)

    def testHalfDuplexCall(self):
        service = _CreateService()

        def half_duplex_request_iterator():
            request = request_pb2.StreamingOutputCallRequest()
            request.response_parameters.add(size=1, interval_us=0)
            yield request
            request = request_pb2.StreamingOutputCallRequest()
            request.response_parameters.add(size=2, interval_us=0)
            request.response_parameters.add(size=3, interval_us=0)
            yield request

        responses = service.stub.HalfDuplexCall(half_duplex_request_iterator())
        expected_responses = service.servicer_methods.HalfDuplexCall(
            half_duplex_request_iterator(), 'not a real RpcContext!')
        for expected_response, response in moves.zip_longest(
                expected_responses, responses):
            self.assertEqual(expected_response, response)
        service.server.stop(None)

    def testHalfDuplexCallWedged(self):
        condition = threading.Condition()
        wait_cell = [False]

        @contextlib.contextmanager
        def wait():  # pylint: disable=invalid-name
            """Holds the request iterator open for the duration of the body.

            Cleanup runs in `finally` so that a failed assertion inside the
            body cannot leave the iterator blocked on `condition` forever.
            """
            # Where's Python 3's 'nonlocal' statement when you need it?
            with condition:
                wait_cell[0] = True
            try:
                yield
            finally:
                with condition:
                    wait_cell[0] = False
                    condition.notify_all()

        def half_duplex_request_iterator():
            request = request_pb2.StreamingOutputCallRequest()
            request.response_parameters.add(size=1, interval_us=0)
            yield request
            # Keep the request stream open (never signalling completion)
            # until wait() exits.
            with condition:
                while wait_cell[0]:
                    condition.wait()

        service = _CreateService()
        with wait():
            responses = service.stub.HalfDuplexCall(
                half_duplex_request_iterator(),
                timeout=test_constants.SHORT_TIMEOUT)
            # half-duplex waits for the client to send all info
            with self.assertRaises(grpc.RpcError) as exception_context:
                next(responses)
        self.assertIs(exception_context.exception.code(),
                      grpc.StatusCode.DEADLINE_EXCEEDED)
        service.server.stop(None)


if __name__ == '__main__':
    unittest.main(verbosity=2)