#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The mock module allows easy mocking of apitools clients.

This module allows you to mock out the constructor of a particular apitools
client, for a specific API and version. Then, when the client is created, it
will be run against an expected session that you define. This way code that is
not aware of the testing framework can construct new clients as normal, as long
as it's all done within the context of a mock.
"""

import difflib
import sys

import six

from apitools.base.protorpclite import messages
from apitools.base.py import base_api
from apitools.base.py import encoding
from apitools.base.py import exceptions


class Error(Exception):

    """Exceptions for this module."""


def _MessagesEqual(msg1, msg2):
    """Compare two protorpc messages for equality.

    Using python's == operator does not work in all cases, specifically when
    there is a list involved, so lists are compared element by element and
    messages are compared field by field, recursively.

    Two messages are only considered equal when they are instances of the
    same message class.

    Args:
      msg1: protorpc.messages.Message or [protorpc.messages.Message] or number
          or string, One of the messages to compare.
      msg2: protorpc.messages.Message or [protorpc.messages.Message] or number
          or string, One of the messages to compare.

    Returns:
      True if the values are isomorphic, False otherwise.
    """
    if isinstance(msg1, list) and isinstance(msg2, list):
        if len(msg1) != len(msg2):
            return False
        return all(_MessagesEqual(x, y) for x, y in zip(msg1, msg2))

    if (not isinstance(msg1, messages.Message) or
            not isinstance(msg2, messages.Message)):
        return msg1 == msg2

    # Without this check, getattr(msg2, field.name) below can raise
    # AttributeError when msg2 lacks a field declared on msg1, and the
    # comparison would ignore any extra fields msg2 declares.
    if type(msg1) is not type(msg2):
        return False

    for field in msg1.all_fields():
        field1 = getattr(msg1, field.name)
        field2 = getattr(msg2, field.name)
        if not _MessagesEqual(field1, field2):
            return False
    return True


class UnexpectedRequestException(Error):

    """Raised when a mocked method receives a call that does not match.

    When the method keys differ the message shows both calls; when they
    match it also includes a unified diff of the two request reprs.
    """

    def __init__(self, received_call, expected_call):
        """Build the error message.

        Args:
          received_call: (key, request) tuple for the call actually made.
          expected_call: (key, request) tuple for the call that was
              expected; (None, None) when no call was expected at all.
        """
        expected_key, expected_request = expected_call
        received_key, received_request = received_call

        expected_repr = encoding.MessageToRepr(
            expected_request, multiline=True)
        received_repr = encoding.MessageToRepr(
            received_request, multiline=True)

        if expected_key != received_key:
            # A different method was called; a request diff adds nothing.
            msg = '\n'.join((
                'expected: {expected_key}({expected_request})',
                'received: {received_key}({received_request})',
                '',
            )).format(
                expected_key=expected_key,
                expected_request=expected_repr,
                received_key=received_key,
                received_request=received_repr)
        else:
            diff = '\n'.join(difflib.unified_diff(
                expected_repr.splitlines(), received_repr.splitlines()))
            msg = '\n'.join((
                'for request to {key},',
                'expected: {expected_request}',
                'received: {received_request}',
                'diff: {diff}',
                '',
            )).format(
                key=expected_key,
                expected_request=expected_repr,
                received_request=received_repr,
                diff=diff)
        super(UnexpectedRequestException, self).__init__(msg)


class ExpectedRequestsException(Error):

    """Raised when expected requests were never made."""

    def __init__(self, expected_calls):
        """Build the error message.

        Args:
          expected_calls: iterable of (key, request) tuples that were
              expected but not consumed.
        """
        msg = 'expected:\n' + ''.join(
            '{key}({request})\n'.format(
                key=key,
                request=encoding.MessageToRepr(request, multiline=True))
            for key, request in expected_calls)
        super(ExpectedRequestsException, self).__init__(msg)


class _ExpectedRequestResponse(object):

    """Encapsulation of an expected request and corresponding response."""

    def __init__(self, key, request, response=None, exception=None):
        """Record one expectation.

        Args:
          key: str, the method key the request is expected on.
          request: the request that is expected.
          response: the value to return, or None.
          exception: an exceptions.Error instance to raise, or None.

        Raises:
          exceptions.ConfigurationValueError: if both response and
              exception are given, if response is an exceptions.Error,
              or if exception is not an exceptions.Error.
        """
        self.__key = key
        self.__request = request

        # Compare against None instead of relying on truthiness so that a
        # falsy response value is still subject to these checks.
        if response is not None and exception is not None:
            raise exceptions.ConfigurationValueError(
                'Should specify at most one of response and exception')
        if isinstance(response, exceptions.Error):
            raise exceptions.ConfigurationValueError(
                'Responses should not be an instance of Error')
        if (exception is not None and
                not isinstance(exception, exceptions.Error)):
            raise exceptions.ConfigurationValueError(
                'Exceptions must be instances of Error')

        self.__response = response
        self.__exception = exception

    @property
    def key(self):
        """The method key this expectation applies to."""
        return self.__key

    @property
    def request(self):
        """The request this expectation matches against."""
        return self.__request

    def ValidateAndRespond(self, key, request):
        """Validate that key and request match expectations, and respond if so.

        Args:
          key: str, Actual key to compare against expectations.
          request: protorpc.messages.Message or [protorpc.messages.Message]
              or number or string, Actual request to compare against
              expectations.

        Raises:
          UnexpectedRequestException: If key or request don't match
              expectations.
          apitools_base.Error: If a non-None exception is specified to
              be thrown.

        Returns:
          The response that was specified to be returned.
        """
        if key != self.__key or not _MessagesEqual(request, self.__request):
            raise UnexpectedRequestException((key, request),
                                             (self.__key, self.__request))

        if self.__exception is not None:
            # Can only throw apitools_base.Error (enforced in __init__).
            raise self.__exception  # pylint: disable=raising-bad-type

        return self.__response


class _MockedMethod(object):

    """A mocked API service method."""

    def __init__(self, key, mocked_client, real_method):
        """Wrap a service method.

        Args:
          key: str, '<api>.<collection>.<method>' identifier of the method.
          mocked_client: the mock Client holding the expectation queue.
          real_method: the method on a real service instance. Its __name__
              and method_config are read here, so it must not be None.
        """
        self.__name__ = real_method.__name__
        self.__key = key
        self.__mocked_client = mocked_client
        self.__real_method = real_method
        self.method_config = real_method.method_config

    def Expect(self, request, response=None, exception=None, **unused_kwargs):
        """Add an expectation on the mocked method.

        At most one of response and exception may be specified. If neither
        is given, the call is forwarded to the real method.

        Args:
          request: The request that should be expected.
          response: The response that should be returned, or None to
              forward the call to the real method.
          exception: An exception that should be thrown, or None.
        """
        # TODO(jasmuth): the unused_kwargs provides a placeholder for
        # future things that can be passed to Expect(), like special
        # params to the method call.

        # pylint: disable=protected-access
        # Class in same module.
        self.__mocked_client._request_responses.append(
            _ExpectedRequestResponse(self.__key,
                                     request,
                                     response=response,
                                     exception=exception))
        # pylint: enable=protected-access

    def __call__(self, request, **unused_kwargs):
        """Consume the next expectation and answer the call with it.

        Raises:
          UnexpectedRequestException: if no expectation is pending, or the
              next one does not match this call.
        """
        # TODO(jasmuth): allow the testing code to expect certain
        # values in these currently unused_kwargs, especially the
        # upload parameter used by media-heavy services like bigquery
        # or bigstore.

        # pylint: disable=protected-access
        # Class in same module.
        pending = self.__mocked_client._request_responses
        # pylint: enable=protected-access
        if not pending:
            raise UnexpectedRequestException(
                (self.__key, request), (None, None))
        # Expectations are consumed in FIFO order across all methods.
        request_response = pending.pop(0)

        response = request_response.ValidateAndRespond(self.__key, request)

        if response is None and self.__real_method:
            # No canned response: call the real service and print what it
            # returned (presumably so it can be pasted into an expectation).
            response = self.__real_method(request)
            print(encoding.MessageToRepr(
                response, multiline=True, shortstrings=True))

        return response


def _MakeMockedService(api_name, collection_name,
                       mock_client, service, real_service):
    """Build a BaseApiService subclass whose methods are _MockedMethods.

    Args:
      api_name: str, '<package>_<url version>' prefix for method keys.
      collection_name: str, the service's collection name.
      mock_client: the mock Client holding the expectation queue.
      service: the real service class, used to enumerate its methods.
      real_service: an instance of the real service, or None.

    Returns:
      A new BaseApiService subclass.
    """
    class MockedService(base_api.BaseApiService):
        pass

    for method in service.GetMethodsList():
        # NOTE(review): when real_service is None, _MockedMethod.__init__
        # dereferences real_method and fails; Client.Mock always supplies a
        # real service today, so this path is not exercised.
        real_method = None
        if real_service:
            real_method = getattr(real_service, method)
        setattr(MockedService,
                method,
                _MockedMethod(api_name + '.' + collection_name + '.' + method,
                              mock_client,
                              real_method))
    return MockedService


class Client(object):

    """Mock an apitools client."""

    def __init__(self, client_class, real_client=None):
        """Mock an apitools API, given its class.

        Args:
          client_class: The class for the API. For example, after
              `from apis.sqladmin import v1beta3` you can pass
              v1beta3.SqladminV1beta3, and anything within the mock's
              context will use your mocked version.
          real_client: apitools Client, The client to make requests
              against when the expected response is None. If omitted, one
              is constructed with client_class(get_credentials=False).
        """
        if not real_client:
            real_client = client_class(get_credentials=False)

        self.__orig_class = self.__class__
        self.__client_class = client_class
        self.__real_service_classes = {}
        self.__real_client = real_client

        # Pending expectations, consumed in order by _MockedMethod.__call__.
        self._request_responses = []
        self.__real_include_fields = None

    def __enter__(self):
        return self.Mock()

    def Mock(self):
        """Stub out the client class with mocked services.

        Swaps this object's class for a subclass of both Client and the
        client class, replaces each BaseApiService subclass found on the
        client class with a mocked service class, attaches an instance of
        each mocked service to this object under the service's _NAME, and
        redirects the client class's IncludeFields to this object.

        Returns:
          self.
        """
        client = self.__real_client or self.__client_class(
            get_credentials=False)

        class Patched(self.__class__, self.__client_class):
            pass
        self.__class__ = Patched

        # Identical for every service of this client, so compute it once.
        # pylint: disable=protected-access
        api_name = '%s_%s' % (self.__client_class._PACKAGE,
                              self.__client_class._URL_VERSION)
        # pylint: enable=protected-access

        for name in dir(self.__client_class):
            service_class = getattr(self.__client_class, name)
            if not isinstance(service_class, type):
                continue
            if not issubclass(service_class, base_api.BaseApiService):
                continue
            self.__real_service_classes[name] = service_class
            # pylint: disable=protected-access
            collection_name = service_class._NAME
            # pylint: enable=protected-access
            mocked_service_class = _MakeMockedService(
                api_name, collection_name, self,
                service_class,
                service_class(client) if self.__real_client else None)

            setattr(self.__client_class, name, mocked_service_class)

            setattr(self, collection_name, mocked_service_class(self))

        self.__real_include_fields = self.__client_class.IncludeFields
        self.__client_class.IncludeFields = self.IncludeFields

        # pylint: disable=attribute-defined-outside-init
        self._url = client._url
        self._http = client._http

        return self

    def __exit__(self, exc_type, value, traceback):
        is_active_exception = value is not None
        # Don't mask an in-flight exception with ExpectedRequestsException.
        self.Unmock(suppress=is_active_exception)
        if is_active_exception:
            six.reraise(exc_type, value, traceback)
        return True

    def Unmock(self, suppress=False):
        """Undo Mock() and check that every expectation was consumed.

        Args:
          suppress: bool, if True, don't raise for leftover expectations.

        Raises:
          ExpectedRequestsException: if expectations remain, suppress is
              False and no exception is currently being handled.
        """
        self.__class__ = self.__orig_class
        for name, service_class in self.__real_service_classes.items():
            setattr(self.__client_class, name, service_class)
            delattr(self, service_class._NAME)
        self.__real_service_classes = {}
        del self._url
        del self._http

        self.__client_class.IncludeFields = self.__real_include_fields
        self.__real_include_fields = None

        requests = [(rq_rs.key, rq_rs.request)
                    for rq_rs in self._request_responses]
        self._request_responses = []

        if requests and not suppress and sys.exc_info()[1] is None:
            raise ExpectedRequestsException(requests)

    def IncludeFields(self, include_fields):
        """Delegate to the real client's IncludeFields, if there is one.

        Returns:
          The real IncludeFields result, or None without a real client.
        """
        if self.__real_client:
            return self.__real_include_fields(self.__real_client,
                                              include_fields)
        return None