1"""The mock module allows easy mocking of apitools clients.
2
3This module allows you to mock out the constructor of a particular apitools
4client, for a specific API and version. Then, when the client is created, it
5will be run against an expected session that you define. This way code that is
6not aware of the testing framework can construct new clients as normal, as long
7as it's all done within the context of a mock.
8"""
9
10import difflib
11
12from protorpc import messages
13import six
14
15import apitools.base.py as apitools_base
16
17
class Error(Exception):
    """Root of the exception hierarchy raised by this mock module."""
21
22
def _MessagesEqual(msg1, msg2):
    """Compare two protorpc messages for equality.

    Using python's == operator does not work in all cases, specifically when
    there is a list involved, so lists are compared element by element and
    messages field by field (recursively).

    Args:
      msg1: protorpc.messages.Message or [protorpc.messages.Message] or number
          or string, One of the messages to compare.
      msg2: protorpc.messages.Message or [protorpc.messages.Message] or number
          or string, One of the messages to compare.

    Returns:
      True if the messages are isomorphic, False otherwise. Two messages of
      different classes are never considered equal.
    """
    if isinstance(msg1, list) and isinstance(msg2, list):
        if len(msg1) != len(msg2):
            return False
        return all(_MessagesEqual(x, y) for x, y in zip(msg1, msg2))

    if (not isinstance(msg1, messages.Message) or
            not isinstance(msg2, messages.Message)):
        return msg1 == msg2

    # Different message classes may share field names and would otherwise
    # compare "equal"; a field present only on msg1 would also make the
    # getattr on msg2 below raise AttributeError.
    if type(msg1) is not type(msg2):
        return False

    return all(_MessagesEqual(getattr(msg1, field.name),
                              getattr(msg2, field.name))
               for field in msg1.all_fields())
52
53
class UnexpectedRequestException(Error):

    """Raised when a mocked method receives a call it did not expect.

    The message shows both the expected and the received call; when the
    method keys match, a unified diff of the two request reprs is appended.
    """

    def __init__(self, received_call, expected_call):
        """Build the exception message.

        Args:
          received_call: (key, request) tuple describing the actual call.
          expected_call: (key, request) tuple describing the expected call;
              both elements may be None when no call was expected at all.
        """
        expected_key, expected_request = expected_call
        received_key, received_request = received_call

        expected_repr = apitools_base.MessageToRepr(
            expected_request, multiline=True)
        received_repr = apitools_base.MessageToRepr(
            received_request, multiline=True)

        if expected_key != received_key:
            msg = '\n'.join((
                'expected: {expected_key}({expected_request})',
                'received: {received_key}({received_request})',
                '',
            )).format(
                expected_key=expected_key,
                expected_request=expected_repr,
                received_key=received_key,
                received_request=received_repr)
        else:
            # splitlines() drops line endings, so use lineterm='' to keep
            # the diff's header lines from each carrying an extra '\n'.
            diff = '\n'.join(difflib.unified_diff(
                expected_repr.splitlines(),
                received_repr.splitlines(),
                lineterm=''))
            msg = '\n'.join((
                'for request to {key},',
                'expected: {expected_request}',
                'received: {received_request}',
                'diff: {diff}',
                '',
            )).format(
                key=expected_key,
                expected_request=expected_repr,
                received_request=received_repr,
                diff=diff)
        super(UnexpectedRequestException, self).__init__(msg)
95
96
class ExpectedRequestsException(Error):

    """Raised when registered expectations were never consumed by a call."""

    def __init__(self, expected_calls):
        """Build a message listing every outstanding expectation.

        Args:
          expected_calls: iterable of (key, request) tuples for the
              expectations that no call consumed.
        """
        entries = [
            '{key}({request})\n'.format(
                key=call_key,
                request=apitools_base.MessageToRepr(call_request,
                                                    multiline=True))
            for call_key, call_request in expected_calls
        ]
        super(ExpectedRequestsException, self).__init__(
            'expected:\n' + ''.join(entries))
106
107
class _ExpectedRequestResponse(object):

    """Encapsulation of an expected request and corresponding response."""

    def __init__(self, key, request, response=None, exception=None):
        """Record one expectation.

        Args:
          key: str, identifies the mocked method the request is expected on.
          request: the request the call is expected to carry.
          response: the value to return for the call, or None.
          exception: an apitools_base.Error instance to raise instead of
              returning a response, or None.

        Raises:
          apitools_base.ConfigurationValueError: if both response and
              exception are given, if response is an Error, or if exception
              is not an Error.
        """
        self.__key = key
        self.__request = request

        # Compare against None instead of relying on truthiness so that
        # falsy responses (e.g. an empty list or 0) are still validated.
        if response is not None and exception is not None:
            raise apitools_base.ConfigurationValueError(
                'Should specify at most one of response and exception')
        if isinstance(response, apitools_base.Error):
            raise apitools_base.ConfigurationValueError(
                'Responses should not be an instance of Error')
        if (exception is not None and
                not isinstance(exception, apitools_base.Error)):
            raise apitools_base.ConfigurationValueError(
                'Exceptions must be instances of Error')

        self.__response = response
        self.__exception = exception

    @property
    def key(self):
        return self.__key

    @property
    def request(self):
        return self.__request

    def ValidateAndRespond(self, key, request):
        """Validate that key and request match expectations, and respond if so.

        Args:
          key: str, Actual key to compare against expectations.
          request: protorpc.messages.Message or [protorpc.messages.Message]
            or number or string, Actual request to compare against
            expectations.

        Raises:
          UnexpectedRequestException: If key or request don't match
              expectations.
          apitools_base.Error: If a non-None exception is specified to
              be thrown.

        Returns:
          The response that was specified to be returned.

        """
        if key != self.__key or not _MessagesEqual(request, self.__request):
            raise UnexpectedRequestException((key, request),
                                             (self.__key, self.__request))

        if self.__exception is not None:
            # Can only throw apitools_base.Error (validated in __init__).
            raise self.__exception  # pylint: disable=raising-bad-type

        return self.__response
164
165
class _MockedService(apitools_base.BaseApiService):

    """A service whose API methods are replaced by _MockedMethod objects."""

    def __init__(self, key, mocked_client, methods, real_service):
        """Create the mocked service.

        Args:
          key: str, prefix for method keys; each method's key is
              key + '.' + <method name>.
          mocked_client: the mock Client that owns the expectation queue.
          methods: iterable of method names to replace with mocks.
          real_service: the real service instance whose state is copied and
              whose methods are used for pass-through calls, or None.
        """
        super(_MockedService, self).__init__(mocked_client)
        # Only copy state when a real service exists; the loop below already
        # treats None as "no real service", so this must not crash on it.
        if real_service is not None:
            self.__dict__.update(real_service.__dict__)
        for method in methods:
            real_method = (getattr(real_service, method)
                           if real_service is not None else None)
            setattr(self, method,
                    _MockedMethod(key + '.' + method,
                                  mocked_client,
                                  real_method))
179
180
class _MockedMethod(object):

    """A mocked API service method."""

    def __init__(self, key, mocked_client, real_method):
        """Create the mocked method.

        Args:
          key: str, key identifying this method in expectations.
          mocked_client: the mock Client whose _request_responses queue
              holds the expectations.
          real_method: callable used when an expectation's response is
              None, or None to simply return None.
        """
        self.__key = key
        self.__mocked_client = mocked_client
        self.__real_method = real_method

    def Expect(self, request, response=None, exception=None, **unused_kwargs):
        """Add an expectation on the mocked method.

        At most one of response and exception may be specified. If neither
        is, the matching call returns None, or is passed through to the real
        method when one is available.

        Args:
          request: The request that should be expected
          response: The response that should be returned or None if
              exception is provided.
          exception: An exception that should be thrown, or None.

        """
        # TODO(jasmuth): the unused_kwargs provides a placeholder for
        # future things that can be passed to Expect(), like special
        # params to the method call.

        # pylint: disable=protected-access
        # Class in same module.
        self.__mocked_client._request_responses.append(
            _ExpectedRequestResponse(self.__key,
                                     request,
                                     response=response,
                                     exception=exception))
        # pylint: enable=protected-access

    def __call__(self, request, **kwargs):
        """Consume the next queued expectation and answer the call with it.

        Args:
          request: the request passed by the code under test.
          **kwargs: extra call arguments; forwarded to the real method on
              pass-through, otherwise ignored.

        Raises:
          UnexpectedRequestException: if no expectation is queued, or the
              next one does not match this call.

        Returns:
          The expected response, or the real method's response when the
          expected response is None and a real method exists.
        """
        # TODO(jasmuth): allow the testing code to expect certain
        # values in these kwargs, especially the upload parameter used
        # by media-heavy services like bigquery or bigstore.

        # pylint: disable=protected-access
        # Class in same module.
        request_responses = self.__mocked_client._request_responses
        # pylint: enable=protected-access
        if not request_responses:
            raise UnexpectedRequestException(
                (self.__key, request), (None, None))
        request_response = request_responses.pop(0)

        response = request_response.ValidateAndRespond(self.__key, request)

        if response is None and self.__real_method is not None:
            # Forward the caller's keyword arguments so the real call sees
            # exactly what the code under test passed.
            response = self.__real_method(request, **kwargs)
            print(apitools_base.MessageToRepr(
                response, multiline=True, shortstrings=True))

        return response
239
240
241def _MakeMockedServiceConstructor(mocked_service):
242    def Constructor(unused_self, unused_client):
243        return mocked_service
244    return Constructor
245
246
class Client(object):

    """Mock an apitools client."""

    def __init__(self, client_class, real_client=None):
        """Mock an apitools API, given its class.

        Args:
          client_class: The class for the API. eg, if you
                from apis.sqladmin import v1beta3
              then you can pass v1beta3.SqladminV1beta3 to this class
              and anything within its context will use your mocked
              version.
          real_client: apitools Client, The client to make requests
              against when the expected response is None.

        """

        if not real_client:
            real_client = client_class(get_credentials=False)

        self.__client_class = client_class
        self.__real_service_classes = {}
        self.__real_client = real_client

        self._request_responses = []

    def __enter__(self):
        return self.Mock()

    def Mock(self):
        """Stub out the client class with mocked services.

        Every attribute of the client class that is a BaseApiService
        subclass is replaced by a constructor returning a _MockedService,
        and the class's IncludeFields is redirected to this object. Each
        mocked service is also exposed on this object under its collection
        name.

        Returns:
          self, so expectations can be registered on the mocked services.
        """
        client = self.__real_client or self.__client_class(
            get_credentials=False)
        for name in dir(self.__client_class):
            service_class = getattr(self.__client_class, name)
            if not isinstance(service_class, type):
                continue
            if not issubclass(service_class, apitools_base.BaseApiService):
                continue
            self.__real_service_classes[name] = service_class
            service = service_class(client)
            # pylint: disable=protected-access
            # Some liberty is allowed with mocking.
            collection_name = service_class._NAME
            # pylint: enable=protected-access
            api_name = '%s_%s' % (self.__client_class._PACKAGE,
                                  self.__client_class._URL_VERSION)
            mocked_service = _MockedService(
                api_name + '.' + collection_name, self,
                service._method_configs.keys(),
                service if self.__real_client else None)
            mocked_constructor = _MakeMockedServiceConstructor(mocked_service)
            setattr(self.__client_class, name, mocked_constructor)

            setattr(self, collection_name, mocked_service)

        self.__real_include_fields = self.__client_class.IncludeFields
        self.__client_class.IncludeFields = self.IncludeFields

        return self

    def __exit__(self, exc_type, value, traceback):
        is_active_exception = value is not None
        # Don't let a complaint about unconsumed expectations mask the
        # exception that actually escaped the with-block.
        self.Unmock(suppress=is_active_exception)
        if is_active_exception:
            six.reraise(exc_type, value, traceback)
        return True

    def Unmock(self, suppress=False):
        """Restore the real client class and verify all expectations were used.

        Args:
          suppress: bool, if True, do not raise for unconsumed expectations
              (used when another exception is already propagating).

        Raises:
          ExpectedRequestsException: if some expected requests were never
              made and suppress is False.
        """
        # Restore everything before possibly raising, so the client class
        # is never left patched after a failed expectation check.
        for name, service_class in self.__real_service_classes.items():
            setattr(self.__client_class, name, service_class)
        self.__client_class.IncludeFields = self.__real_include_fields

        pending = [(rq_rs.key, rq_rs.request)
                   for rq_rs in self._request_responses]
        # Reset the queue so stale expectations are not reported again if
        # this mock is reused.
        self._request_responses = []

        if pending and not suppress:
            raise ExpectedRequestsException(pending)

    def IncludeFields(self, include_fields):
        """Stand-in for the client class's IncludeFields while mocked.

        Args:
          include_fields: passed through unchanged to the real IncludeFields.

        Returns:
          Whatever the real IncludeFields returns for the real client, or
          None when there is no real client.
        """
        if self.__real_client:
            return self.__real_include_fields(self.__real_client,
                                              include_fields)
        return None
330