"""The mock module allows easy mocking of apitools clients.

This module allows you to mock out the constructor of a particular apitools
client, for a specific API and version. Then, when the client is created, it
will be run against an expected session that you define. This way code that is
not aware of the testing framework can construct new clients as normal, as long
as it's all done within the context of a mock.
"""

import difflib

from protorpc import messages
import six

import apitools.base.py as apitools_base


class Error(Exception):

    """Exceptions for this module."""


def _MessagesEqual(msg1, msg2):
    """Compare two protorpc messages for equality.

    Using python's == operator does not work in all cases, specifically when
    there is a list involved.

    Args:
      msg1: protorpc.messages.Message or [protorpc.messages.Message] or number
          or string, One of the messages to compare.
      msg2: protorpc.messages.Message or [protorpc.messages.Message] or number
          or string, One of the messages to compare.

    Returns:
      True if the messages are isomorphic, False otherwise.
    """
    if isinstance(msg1, list) and isinstance(msg2, list):
        if len(msg1) != len(msg2):
            return False
        return all(_MessagesEqual(x, y) for x, y in zip(msg1, msg2))

    if (not isinstance(msg1, messages.Message) or
            not isinstance(msg2, messages.Message)):
        return msg1 == msg2
    # Messages of different types are never equal. Without this check the
    # field walk below would raise AttributeError for a field that only
    # msg1's type declares, and would ignore fields that only msg2 has.
    if type(msg1) is not type(msg2):
        return False
    for field in msg1.all_fields():
        field1 = getattr(msg1, field.name)
        field2 = getattr(msg2, field.name)
        if not _MessagesEqual(field1, field2):
            return False
    return True


class UnexpectedRequestException(Error):

    """Raised when a mocked method receives a call that was not expected.

    The message shows the expected and received keys and requests; when the
    keys agree it also includes a unified diff of the two request reprs.
    """

    def __init__(self, received_call, expected_call):
        """Build the exception message.

        Args:
          received_call: (key, request) tuple describing the call that was
              actually made.
          expected_call: (key, request) tuple describing the call that was
              expected; both entries are None when no call was expected.
        """
        expected_key, expected_request = expected_call
        received_key, received_request = received_call

        expected_repr = apitools_base.MessageToRepr(
            expected_request, multiline=True)
        received_repr = apitools_base.MessageToRepr(
            received_request, multiline=True)

        if expected_key != received_key:
            msg = '\n'.join((
                'expected: {expected_key}({expected_request})',
                'received: {received_key}({received_request})',
                '',
            )).format(
                expected_key=expected_key,
                expected_request=expected_repr,
                received_key=received_key,
                received_request=received_repr)
        else:
            # splitlines() strips line endings, so lineterm='' keeps the
            # '---'/'+++'/'@@' header lines from carrying a stray newline
            # that the join below would turn into blank lines.
            diff = '\n'.join(difflib.unified_diff(
                expected_repr.splitlines(),
                received_repr.splitlines(),
                lineterm=''))
            msg = '\n'.join((
                'for request to {key},',
                'expected: {expected_request}',
                'received: {received_request}',
                'diff: {diff}',
                '',
            )).format(
                key=expected_key,
                expected_request=expected_repr,
                received_request=received_repr,
                diff=diff)
        super(UnexpectedRequestException, self).__init__(msg)


class ExpectedRequestsException(Error):

    """Raised when expected requests were never made."""

    def __init__(self, expected_calls):
        """Build the exception message.

        Args:
          expected_calls: iterable of (key, request) tuples that were
              expected but never received.
        """
        lines = ['expected:']
        for key, request in expected_calls:
            lines.append('{key}({request})'.format(
                key=key,
                request=apitools_base.MessageToRepr(request, multiline=True)))
        msg = '\n'.join(lines) + '\n'
        super(ExpectedRequestsException, self).__init__(msg)


class _ExpectedRequestResponse(object):

    """Encapsulation of an expected request and corresponding response."""

    def __init__(self, key, request, response=None, exception=None):
        """Record one expectation.

        Args:
          key: str, The method key the request is expected on.
          request: The request that is expected.
          response: The response to return, or None.
          exception: An apitools_base.Error instance to raise, or None.

        Raises:
          apitools_base.ConfigurationValueError: If both response and
              exception are given, if response is an Error, or if exception
              is not an Error.
        """
        self.__key = key
        self.__request = request

        if response is not None and exception is not None:
            raise apitools_base.ConfigurationValueError(
                'Should specify at most one of response and exception')
        if isinstance(response, apitools_base.Error):
            raise apitools_base.ConfigurationValueError(
                'Responses should not be an instance of Error')
        if (exception is not None and
                not isinstance(exception, apitools_base.Error)):
            raise apitools_base.ConfigurationValueError(
                'Exceptions must be instances of Error')

        self.__response = response
        self.__exception = exception

    @property
    def key(self):
        return self.__key

    @property
    def request(self):
        return self.__request

    def ValidateAndRespond(self, key, request):
        """Validate that key and request match expectations, and respond if so.

        Args:
          key: str, Actual key to compare against expectations.
          request: protorpc.messages.Message or [protorpc.messages.Message]
              or number or string, Actual request to compare against
              expectations.

        Raises:
          UnexpectedRequestException: If key or request don't match
              expectations.
          apitools_base.Error: If a non-None exception is specified to
              be thrown.

        Returns:
          The response that was specified to be returned.
        """
        if key != self.__key or not _MessagesEqual(request, self.__request):
            raise UnexpectedRequestException((key, request),
                                             (self.__key, self.__request))

        if self.__exception is not None:
            # Can only throw apitools_base.Error.
            raise self.__exception  # pylint: disable=raising-bad-type

        return self.__response


class _MockedService(apitools_base.BaseApiService):

    """An API service whose methods are replaced by _MockedMethods."""

    def __init__(self, key, mocked_client, methods, real_service):
        """Build the mocked service.

        Args:
          key: str, Prefix for the keys of the mocked methods.
          mocked_client: Client, The mock that owns the expectation queue.
          methods: iterable of str, Names of the methods to mock.
          real_service: The real service to fall back to, or None.
        """
        super(_MockedService, self).__init__(mocked_client)
        if real_service is not None:
            # Copy the real service's instance attributes onto this mock.
            self.__dict__.update(real_service.__dict__)
        for method in methods:
            real_method = None
            if real_service is not None:
                real_method = getattr(real_service, method)
            setattr(self, method,
                    _MockedMethod(key + '.' + method,
                                  mocked_client,
                                  real_method))


class _MockedMethod(object):

    """A mocked API service method."""

    def __init__(self, key, mocked_client, real_method):
        self.__key = key
        self.__mocked_client = mocked_client
        self.__real_method = real_method

    def Expect(self, request, response=None, exception=None, **unused_kwargs):
        """Add an expectation on the mocked method.

        Exactly one of response and exception should be specified.

        Args:
          request: The request that should be expected
          response: The response that should be returned or None if
              exception is provided.
          exception: An exception that should be thrown, or None.
        """
        # TODO(jasmuth): the unused_kwargs provides a placeholder for
        # future things that can be passed to Expect(), like special
        # params to the method call.

        # pylint: disable=protected-access
        # Class in same module.
        self.__mocked_client._request_responses.append(
            _ExpectedRequestResponse(self.__key,
                                     request,
                                     response=response,
                                     exception=exception))
        # pylint: enable=protected-access

    def __call__(self, request, **unused_kwargs):
        """Consume the next expectation and respond to request.

        If the expectation's response is None and a real method is
        available, the real method is called and its response (also printed
        to stdout) is returned.

        Raises:
          UnexpectedRequestException: If no expectations remain, or the next
              one does not match this call.
        """
        # TODO(jasmuth): allow the testing code to expect certain
        # values in these currently unused_kwargs, especially the
        # upload parameter used by media-heavy services like bigquery
        # or bigstore.

        # pylint: disable=protected-access
        # Class in same module.
        pending = self.__mocked_client._request_responses
        # pylint: enable=protected-access
        if not pending:
            raise UnexpectedRequestException(
                (self.__key, request), (None, None))
        request_response = pending.pop(0)

        response = request_response.ValidateAndRespond(self.__key, request)

        if response is None and self.__real_method:
            response = self.__real_method(request)
            print(apitools_base.MessageToRepr(
                response, multiline=True, shortstrings=True))

        return response


def _MakeMockedServiceConstructor(mocked_service):
    """Return a service-class stand-in that always yields mocked_service."""
    def Constructor(unused_self, unused_client):
        return mocked_service
    return Constructor


class Client(object):

    """Mock an apitools client."""

    def __init__(self, client_class, real_client=None):
        """Mock an apitools API, given its class.

        Args:
          client_class: The class for the API. eg, if you
              from apis.sqladmin import v1beta3
              then you can pass v1beta3.SqladminV1beta3 to this class
              and anything within its context will use your mocked
              version.
          real_client: apitools Client, The client to make requests
              against when the expected response is None.
        """
        # NOTE(review): because a default client is always constructed here,
        # self.__real_client is never falsy, so Mock() always passes the real
        # services through and any expectation whose response is None falls
        # through to the real method. Confirm this is intended.
        if not real_client:
            real_client = client_class(get_credentials=False)

        self.__client_class = client_class
        self.__real_service_classes = {}
        self.__real_client = real_client
        # Original IncludeFields of client_class; None while not mocked.
        self.__real_include_fields = None

        self._request_responses = []

    def __enter__(self):
        return self.Mock()

    def Mock(self):
        """Stub out the client class with mocked services.

        Every BaseApiService subclass found as an attribute of the client
        class is replaced by a constructor that returns a _MockedService,
        which is also exposed on this object under the service's _NAME. The
        client class's IncludeFields is redirected to self.IncludeFields.

        Returns:
          self, so expectations can be registered on the mocked services.
        """
        client = self.__real_client or self.__client_class(
            get_credentials=False)
        # pylint: disable=protected-access
        # Some liberty is allowed with mocking.
        api_name = '%s_%s' % (self.__client_class._PACKAGE,
                              self.__client_class._URL_VERSION)
        for name in dir(self.__client_class):
            service_class = getattr(self.__client_class, name)
            if not isinstance(service_class, type):
                continue
            if not issubclass(service_class, apitools_base.BaseApiService):
                continue
            self.__real_service_classes[name] = service_class
            service = service_class(client)
            collection_name = service_class._NAME
            mocked_service = _MockedService(
                api_name + '.' + collection_name, self,
                service._method_configs.keys(),
                service if self.__real_client else None)
            mocked_constructor = _MakeMockedServiceConstructor(mocked_service)
            setattr(self.__client_class, name, mocked_constructor)

            setattr(self, collection_name, mocked_service)
        # pylint: enable=protected-access

        # Only capture the original once; a repeated Mock() would otherwise
        # capture our own patched method and Unmock could never restore it.
        if self.__real_include_fields is None:
            self.__real_include_fields = self.__client_class.IncludeFields
        self.__client_class.IncludeFields = self.IncludeFields

        return self

    def __exit__(self, exc_type, value, traceback):
        """Unmock, re-raising any exception raised inside the with block.

        While an exception is propagating, unconsumed expectations are not
        reported, so ExpectedRequestsException cannot mask the original error.
        """
        exception_active = value is not None
        self.Unmock(suppress=exception_active)
        if exception_active:
            six.reraise(exc_type, value, traceback)
        return True

    def Unmock(self, suppress=False):
        """Restore the client class and check all expectations were consumed.

        Args:
          suppress: bool, If True, do not raise for expected requests that
              were never made.

        Raises:
          ExpectedRequestsException: If expected requests were never made and
              suppress is False.
        """
        for name, service_class in self.__real_service_classes.items():
            setattr(self.__client_class, name, service_class)
        self.__real_service_classes = {}

        # Restore before the expectation check so that raising below never
        # leaves the client class patched.
        if self.__real_include_fields is not None:
            self.__client_class.IncludeFields = self.__real_include_fields
            self.__real_include_fields = None

        unconsumed = [(rq_rs.key, rq_rs.request)
                      for rq_rs in self._request_responses]
        self._request_responses = []
        if unconsumed and not suppress:
            raise ExpectedRequestsException(unconsumed)

    def IncludeFields(self, include_fields):
        """Forward IncludeFields on the mocked client class to the real client.

        Returns:
          The result of the client class's original IncludeFields called on
          the real client, or None when there is no real client.
        """
        if self.__real_client:
            return self.__real_include_fields(self.__real_client,
                                              include_fields)
        return None