1#
2# Copyright 2015 Google Inc.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""The mock module allows easy mocking of apitools clients.
17
18This module allows you to mock out the constructor of a particular apitools
19client, for a specific API and version. Then, when the client is created, it
20will be run against an expected session that you define. This way code that is
21not aware of the testing framework can construct new clients as normal, as long
22as it's all done within the context of a mock.
23"""
24
25import difflib
26import sys
27
28import six
29
30from apitools.base.protorpclite import messages
31from apitools.base.py import base_api
32from apitools.base.py import encoding
33from apitools.base.py import exceptions
34
35
class Error(Exception):
    """Base class for every exception raised by this mock module."""
39
40
def _MessagesEqual(msg1, msg2):
    """Compare two protorpc messages for equality.

    Using python's == operator does not work in all cases, specifically when
    there is a list involved.

    Args:
      msg1: protorpc.messages.Message or [protorpc.messages.Message] or number
          or string, One of the messages to compare.
      msg2: protorpc.messages.Message or [protorpc.messages.Message] or number
          or string, One of the messages to compare.

    Returns:
      True if the messages are isomorphic: lists of equal length whose
      elements compare equal pairwise, messages of the same type whose
      fields compare equal recursively, or other values that are ==.
    """
    if isinstance(msg1, list) and isinstance(msg2, list):
        if len(msg1) != len(msg2):
            return False
        return all(_MessagesEqual(x, y) for x, y in zip(msg1, msg2))

    if (not isinstance(msg1, messages.Message) or
            not isinstance(msg2, messages.Message)):
        return msg1 == msg2

    # The field walk below only visits msg1's fields. Without this guard,
    # a msg2 of a different type could compare equal (extra fields ignored)
    # or fail in getattr on a field it does not define.
    if type(msg1) is not type(msg2):
        return False

    for field in msg1.all_fields():
        field1 = getattr(msg1, field.name)
        field2 = getattr(msg2, field.name)
        if not _MessagesEqual(field1, field2):
            return False
    return True
70
71
class UnexpectedRequestException(Error):

    """Raised when a mocked method receives a request it did not expect."""

    def __init__(self, received_call, expected_call):
        """Build a message describing the mismatch.

        Args:
          received_call: (key, request) tuple for the call actually made.
          expected_call: (key, request) tuple for the call that was expected
              next. The mocked method in this module passes (None, None)
              when no expectation remains.
        """
        expected_key, expected_request = expected_call
        received_key, received_request = received_call

        expected_repr = encoding.MessageToRepr(
            expected_request, multiline=True)
        received_repr = encoding.MessageToRepr(
            received_request, multiline=True)

        if expected_key != received_key:
            msg = '\n'.join((
                'expected: {expected_key}({expected_request})',
                'received: {received_key}({received_request})',
                '',
            )).format(
                expected_key=expected_key,
                expected_request=expected_repr,
                received_key=received_key,
                received_request=received_repr)
        else:
            # splitlines() strips line endings, so lineterm='' keeps the
            # diff's ---/+++/@@ header lines from carrying an extra newline
            # once everything is joined with '\n'.
            diff = '\n'.join(difflib.unified_diff(
                expected_repr.splitlines(),
                received_repr.splitlines(),
                fromfile='expected',
                tofile='received',
                lineterm=''))
            msg = '\n'.join((
                'for request to {key},',
                'expected: {expected_request}',
                'received: {received_request}',
                'diff: {diff}',
                '',
            )).format(
                key=expected_key,
                expected_request=expected_repr,
                received_request=received_repr,
                diff=diff)
        super(UnexpectedRequestException, self).__init__(msg)
113
114
class ExpectedRequestsException(Error):

    """Raised when expected requests were still pending at unmock time."""

    def __init__(self, expected_calls):
        """Build a message listing every pending (key, request) pair.

        Args:
          expected_calls: iterable of (key, request) tuples that were
              expected but never made.
        """
        parts = ['expected:\n']
        parts.extend(
            '{key}({request})\n'.format(
                key=pending_key,
                request=encoding.MessageToRepr(pending_request,
                                               multiline=True))
            for pending_key, pending_request in expected_calls)
        super(ExpectedRequestsException, self).__init__(''.join(parts))
124
125
class _ExpectedRequestResponse(object):

    """Encapsulation of an expected request and corresponding response."""

    def __init__(self, key, request, response=None, exception=None):
        """Record one expectation.

        Args:
          key: str, the method key the request is expected on.
          request: the request expected to be received.
          response: the value to return for a matching request, or None.
          exception: an exceptions.Error instance to raise for a matching
              request, or None.

        Raises:
          exceptions.ConfigurationValueError: If both response and exception
              are given, if response is an Error, or if exception is not an
              Error.
        """
        self.__key = key
        self.__request = request

        # Compare against None rather than relying on truthiness, so falsy
        # but meaningful values (e.g. an empty list response) are validated.
        if response is not None and exception is not None:
            raise exceptions.ConfigurationValueError(
                'Should specify at most one of response and exception')
        if response is not None and isinstance(response, exceptions.Error):
            raise exceptions.ConfigurationValueError(
                'Responses should not be an instance of Error')
        if exception is not None and not isinstance(exception,
                                                    exceptions.Error):
            raise exceptions.ConfigurationValueError(
                'Exceptions must be instances of Error')

        self.__response = response
        self.__exception = exception

    @property
    def key(self):
        """The method key this expectation applies to."""
        return self.__key

    @property
    def request(self):
        """The request this expectation matches."""
        return self.__request

    def ValidateAndRespond(self, key, request):
        """Validate that key and request match expectations, and respond if so.

        Args:
          key: str, Actual key to compare against expectations.
          request: protorpc.messages.Message or [protorpc.messages.Message]
            or number or string, Actual request to compare against
            expectations.

        Raises:
          UnexpectedRequestException: If key or request don't match
              expectations.
          exceptions.Error: If a non-None exception is specified to
              be thrown.

        Returns:
          The response that was specified to be returned (may be None).

        """
        if key != self.__key or not _MessagesEqual(request, self.__request):
            raise UnexpectedRequestException((key, request),
                                             (self.__key, self.__request))

        if self.__exception is not None:
            # Can only throw exceptions.Error (validated in __init__).
            raise self.__exception  # pylint: disable=raising-bad-type

        return self.__response
182
183
class _MockedMethod(object):

    """A mocked API service method.

    Each call pops the next expectation from the mocked client's
    `_request_responses` queue and validates the request against it.
    """

    def __init__(self, key, mocked_client, real_method):
        """Wrap one service method.

        Args:
          key: str, dotted "api.collection.method" identifier.
          mocked_client: the mock Client owning the expectation queue.
          real_method: the real bound service method, or None when there is
              no real service to fall back on.
        """
        self.__key = key
        self.__mocked_client = mocked_client
        self.__real_method = real_method
        if real_method is not None:
            self.__name__ = real_method.__name__
            self.method_config = real_method.method_config
        else:
            # No real method to copy from: derive the name from the last
            # component of the dotted key instead of crashing on None.
            self.__name__ = key.rsplit('.', 1)[-1]
            self.method_config = None

    def Expect(self, request, response=None, exception=None, **unused_kwargs):
        """Add an expectation on the mocked method.

        At most one of response and exception may be specified. If the
        matched expectation yields a None response and a real method is
        available, the call is forwarded to the real method.

        Args:
          request: The request that should be expected
          response: The response that should be returned or None if
              exception is provided.
          exception: An exception that should be thrown, or None.

        """
        # TODO(jasmuth): the unused_kwargs provides a placeholder for
        # future things that can be passed to Expect(), like special
        # params to the method call.

        # pylint: disable=protected-access
        # Class in same module.
        self.__mocked_client._request_responses.append(
            _ExpectedRequestResponse(self.__key,
                                     request,
                                     response=response,
                                     exception=exception))
        # pylint: enable=protected-access

    def __call__(self, request, **unused_kwargs):
        """Validate request against the next expectation and respond.

        Raises:
          UnexpectedRequestException: If no expectation is queued or the
              next one does not match.
        """
        # TODO(jasmuth): allow the testing code to expect certain
        # values in these currently unused_kwargs, especially the
        # upload parameter used by media-heavy services like bigquery
        # or bigstore.

        # pylint: disable=protected-access
        # Class in same module.
        pending = self.__mocked_client._request_responses
        # pylint: enable=protected-access
        if not pending:
            raise UnexpectedRequestException(
                (self.__key, request), (None, None))
        request_response = pending.pop(0)

        response = request_response.ValidateAndRespond(self.__key, request)

        if response is None and self.__real_method is not None:
            response = self.__real_method(request)
            # Echo the real response; presumably so it can be pasted into a
            # test expectation -- NOTE(review): confirm intent.
            print(encoding.MessageToRepr(
                response, multiline=True, shortstrings=True))

        return response
244
245
def _MakeMockedService(api_name, collection_name,
                       mock_client, service, real_service):
    """Create a BaseApiService subclass whose methods are all mocked.

    Args:
      api_name: str, "<package>_<version>" prefix for method keys.
      collection_name: str, the service's collection name.
      mock_client: the mock Client that owns the expectation queue.
      service: the real service class; its GetMethodsList() names the
          methods to mock.
      real_service: an instance of the real service to fall back on, or
          a falsy value for none.

    Returns:
      A new service class with one _MockedMethod per listed method.
    """
    class MockedService(base_api.BaseApiService):
        pass

    for method_name in service.GetMethodsList():
        fallback = getattr(real_service, method_name) if real_service else None
        method_key = '.'.join((api_name, collection_name, method_name))
        setattr(MockedService, method_name,
                _MockedMethod(method_key, mock_client, fallback))
    return MockedService
261
262
class Client(object):

    """Mock an apitools client."""

    def __init__(self, client_class, real_client=None):
        """Mock an apitools API, given its class.

        Args:
          client_class: The class for the API. eg, if you
                from apis.sqladmin import v1beta3
              then you can pass v1beta3.SqladminV1beta3 to this class
              and anything within its context will use your mocked
              version.
          real_client: apitools Client, The client to make requests
              against when the expected response is None. If omitted,
              one is built with client_class(get_credentials=False).

        """

        if not real_client:
            real_client = client_class(get_credentials=False)

        self.__orig_class = self.__class__
        self.__client_class = client_class
        # Attribute name on client_class -> original service class, filled
        # by Mock() so Unmock() can put them back.
        self.__real_service_classes = {}
        self.__real_client = real_client

        # FIFO queue of expectations consumed by the mocked methods.
        self._request_responses = []
        # client_class's original IncludeFields while mocked; else None.
        self.__real_include_fields = None

    def __enter__(self):
        return self.Mock()

    def Mock(self):
        """Stub out the client class with mocked services.

        Returns:
          self, re-classed so it is also an instance of client_class.
        """
        client = self.__real_client or self.__client_class(
            get_credentials=False)

        # Make this object also behave as an instance of client_class.
        class Patched(self.__class__, self.__client_class):
            pass
        self.__class__ = Patched

        for name in dir(self.__client_class):
            service_class = getattr(self.__client_class, name)
            if not isinstance(service_class, type):
                continue
            if not issubclass(service_class, base_api.BaseApiService):
                continue
            self.__real_service_classes[name] = service_class
            # pylint: disable=protected-access
            collection_name = service_class._NAME
            api_name = '%s_%s' % (self.__client_class._PACKAGE,
                                  self.__client_class._URL_VERSION)
            # pylint: enable=protected-access
            mocked_service_class = _MakeMockedService(
                api_name, collection_name, self,
                service_class,
                service_class(client) if self.__real_client else None)

            setattr(self.__client_class, name, mocked_service_class)

            setattr(self, collection_name, mocked_service_class(self))

        self.__real_include_fields = self.__client_class.IncludeFields
        self.__client_class.IncludeFields = self.IncludeFields

        # pylint: disable=attribute-defined-outside-init,protected-access
        self._url = client._url
        self._http = client._http
        # pylint: enable=attribute-defined-outside-init,protected-access

        return self

    def __exit__(self, exc_type, value, traceback):
        # While an exception is propagating, don't mask it with an
        # ExpectedRequestsException about unconsumed expectations.
        self.Unmock(suppress=value is not None)
        # A falsy return lets any in-flight exception propagate unchanged.
        return False

    def Unmock(self, suppress=False):
        """Restore client_class and check that all expectations were used.

        Safe to call without a preceding Mock() or more than once.

        Args:
          suppress: bool, if True never raise about unconsumed expectations.

        Raises:
          ExpectedRequestsException: If expectations remain, suppress is
              False, and no exception is currently being handled.
        """
        self.__class__ = self.__orig_class
        for name, service_class in self.__real_service_classes.items():
            setattr(self.__client_class, name, service_class)
            delattr(self, service_class._NAME)  # pylint: disable=protected-access
        self.__real_service_classes = {}
        # Mock() sets these; tolerate their absence instead of raising.
        vars(self).pop('_url', None)
        vars(self).pop('_http', None)

        # Only restore when Mock() actually saved the original; otherwise
        # this would overwrite client_class.IncludeFields with None.
        if self.__real_include_fields is not None:
            self.__client_class.IncludeFields = self.__real_include_fields
            self.__real_include_fields = None

        requests = [(rq_rs.key, rq_rs.request)
                    for rq_rs in self._request_responses]
        self._request_responses = []

        if requests and not suppress and sys.exc_info()[1] is None:
            raise ExpectedRequestsException(requests)

    def IncludeFields(self, include_fields):
        """Forward IncludeFields to the real client while mocked.

        Returns:
          The real IncludeFields result, or None if there is no real client.
        """
        if self.__real_client:
            return self.__real_include_fields(self.__real_client,
                                              include_fields)
        return None
364