1#
2# Copyright 2015 Google Inc.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Tests for apitools.base.py.batch."""
17
18import textwrap
19
20import mock
21from six.moves import http_client
22from six.moves import range  # pylint:disable=redefined-builtin
23from six.moves.urllib import parse
24import unittest2
25
26from apitools.base.py import batch
27from apitools.base.py import exceptions
28from apitools.base.py import http_wrapper
29
30
class FakeCredentials(object):
    """Credentials stand-in that only counts how often refresh() is called."""

    def __init__(self):
        # Incremented once per refresh() call.
        self.num_refreshes = 0

    def refresh(self, _):
        """Records one refresh; the positional argument is ignored."""
        self.num_refreshes = self.num_refreshes + 1
38
39
class FakeHttp(object):
    """Http stand-in that only exposes a `request` attribute."""

    class FakeRequest(object):
        """Holder whose `credentials` attribute is set only when provided."""

        def __init__(self, credentials=None):
            # Leave the attribute undefined entirely when no credentials are
            # given, so hasattr() checks see "no credentials".
            if credentials is None:
                return
            self.credentials = credentials

    def __init__(self, credentials=None):
        fake_request = FakeHttp.FakeRequest(credentials=credentials)
        self.request = fake_request
50
51
class FakeService(object):
    """A service for testing.

    Configs are empty dicts, PrepareHttpRequest returns whatever object the
    caller placed under global_params['desired_request'], and responses are
    passed through unchanged.
    """

    def GetMethodConfig(self, _):
        return dict()

    def GetUploadConfig(self, _):
        return dict()

    # pylint: disable=unused-argument
    def PrepareHttpRequest(
            self, method_config, request, global_params, upload_config):
        prepared = global_params['desired_request']
        return prepared
    # pylint: enable=unused-argument

    def ProcessHttpResponse(self, _, http_response):
        processed = http_response
        return processed
70
71
72class BatchTest(unittest2.TestCase):
73
74    def assertUrlEqual(self, expected_url, provided_url):
75
76        def parse_components(url):
77            parsed = parse.urlsplit(url)
78            query = parse.parse_qs(parsed.query)
79            return parsed._replace(query=''), query
80
81        expected_parse, expected_query = parse_components(expected_url)
82        provided_parse, provided_query = parse_components(provided_url)
83
84        self.assertEqual(expected_parse, provided_parse)
85        self.assertEqual(expected_query, provided_query)
86
87    def __ConfigureMock(self, mock_request, expected_request, response):
88
89        if isinstance(response, list):
90            response = list(response)
91
92        def CheckRequest(_, request, **unused_kwds):
93            self.assertUrlEqual(expected_request.url, request.url)
94            self.assertEqual(expected_request.http_method, request.http_method)
95            if isinstance(response, list):
96                return response.pop(0)
97            return response
98
99        mock_request.side_effect = CheckRequest
100
101    def testRequestServiceUnavailable(self):
102        mock_service = FakeService()
103
104        desired_url = 'https://www.example.com'
105        batch_api_request = batch.BatchApiRequest(batch_url=desired_url,
106                                                  retryable_codes=[])
107        # The request to be added. The actual request sent will be somewhat
108        # larger, as this is added to a batch.
109        desired_request = http_wrapper.Request(desired_url, 'POST', {
110            'content-type': 'multipart/mixed; boundary="None"',
111            'content-length': 80,
112        }, 'x' * 80)
113
114        with mock.patch.object(http_wrapper, 'MakeRequest',
115                               autospec=True) as mock_request:
116            self.__ConfigureMock(
117                mock_request,
118                http_wrapper.Request(desired_url, 'POST', {
119                    'content-type': 'multipart/mixed; boundary="None"',
120                    'content-length': 419,
121                }, 'x' * 419),
122                http_wrapper.Response({
123                    'status': '200',
124                    'content-type': 'multipart/mixed; boundary="boundary"',
125                }, textwrap.dedent("""\
126                --boundary
127                content-type: text/plain
128                content-id: <id+0>
129
130                HTTP/1.1 503 SERVICE UNAVAILABLE
131                nope
132                --boundary--"""), None))
133
134            batch_api_request.Add(
135                mock_service, 'unused', None,
136                global_params={'desired_request': desired_request})
137
138            api_request_responses = batch_api_request.Execute(
139                FakeHttp(), sleep_between_polls=0)
140
141            self.assertEqual(1, len(api_request_responses))
142
143            # Make sure we didn't retry non-retryable code 503.
144            self.assertEqual(1, mock_request.call_count)
145
146            self.assertTrue(api_request_responses[0].is_error)
147            self.assertIsNone(api_request_responses[0].response)
148            self.assertIsInstance(api_request_responses[0].exception,
149                                  exceptions.HttpError)
150
151    def testSingleRequestInBatch(self):
152        desired_url = 'https://www.example.com'
153
154        callback_was_called = []
155
156        def _Callback(response, exception):
157            self.assertEqual({'status': '200'}, response.info)
158            self.assertEqual('content', response.content)
159            self.assertEqual(desired_url, response.request_url)
160            self.assertIsNone(exception)
161            callback_was_called.append(1)
162
163        mock_service = FakeService()
164
165        batch_api_request = batch.BatchApiRequest(batch_url=desired_url)
166        # The request to be added. The actual request sent will be somewhat
167        # larger, as this is added to a batch.
168        desired_request = http_wrapper.Request(desired_url, 'POST', {
169            'content-type': 'multipart/mixed; boundary="None"',
170            'content-length': 80,
171        }, 'x' * 80)
172
173        with mock.patch.object(http_wrapper, 'MakeRequest',
174                               autospec=True) as mock_request:
175            self.__ConfigureMock(
176                mock_request,
177                http_wrapper.Request(desired_url, 'POST', {
178                    'content-type': 'multipart/mixed; boundary="None"',
179                    'content-length': 419,
180                }, 'x' * 419),
181                http_wrapper.Response({
182                    'status': '200',
183                    'content-type': 'multipart/mixed; boundary="boundary"',
184                }, textwrap.dedent("""\
185                --boundary
186                content-type: text/plain
187                content-id: <id+0>
188
189                HTTP/1.1 200 OK
190                content
191                --boundary--"""), None))
192
193            batch_api_request.Add(mock_service, 'unused', None, {
194                'desired_request': desired_request,
195            })
196
197            api_request_responses = batch_api_request.Execute(
198                FakeHttp(), batch_request_callback=_Callback)
199
200            self.assertEqual(1, len(api_request_responses))
201            self.assertEqual(1, mock_request.call_count)
202
203            self.assertFalse(api_request_responses[0].is_error)
204
205            response = api_request_responses[0].response
206            self.assertEqual({'status': '200'}, response.info)
207            self.assertEqual('content', response.content)
208            self.assertEqual(desired_url, response.request_url)
209        self.assertEquals(1, len(callback_was_called))
210
211    def _MakeResponse(self, number_of_parts):
212        return http_wrapper.Response(
213            info={
214                'status': '200',
215                'content-type': 'multipart/mixed; boundary="boundary"',
216            },
217            content='--boundary\n' + '--boundary\n'.join(
218                textwrap.dedent("""\
219                    content-type: text/plain
220                    content-id: <id+{0}>
221
222                    HTTP/1.1 200 OK
223                    response {0} content
224
225                    """)
226                .format(i) for i in range(number_of_parts)) + '--boundary--',
227            request_url=None,
228        )
229
230    def _MakeSampleRequest(self, url, name):
231        return http_wrapper.Request(url, 'POST', {
232            'content-type': 'multipart/mixed; boundary="None"',
233            'content-length': 80,
234        }, '{0} {1}'.format(name, 'x' * (79 - len(name))))
235
236    def testMultipleRequestInBatchWithMax(self):
237        mock_service = FakeService()
238
239        desired_url = 'https://www.example.com'
240        batch_api_request = batch.BatchApiRequest(batch_url=desired_url)
241
242        number_of_requests = 10
243        max_batch_size = 3
244        for i in range(number_of_requests):
245            batch_api_request.Add(
246                mock_service, 'unused', None,
247                {'desired_request': self._MakeSampleRequest(
248                    desired_url, 'Sample-{0}'.format(i))})
249
250        responses = []
251        for i in range(0, number_of_requests, max_batch_size):
252            responses.append(
253                self._MakeResponse(
254                    min(number_of_requests - i, max_batch_size)))
255        with mock.patch.object(http_wrapper, 'MakeRequest',
256                               autospec=True) as mock_request:
257            self.__ConfigureMock(
258                mock_request,
259                expected_request=http_wrapper.Request(desired_url, 'POST', {
260                    'content-type': 'multipart/mixed; boundary="None"',
261                    'content-length': 1142,
262                }, 'x' * 1142),
263                response=responses)
264            api_request_responses = batch_api_request.Execute(
265                FakeHttp(), max_batch_size=max_batch_size)
266
267        self.assertEqual(number_of_requests, len(api_request_responses))
268        self.assertEqual(
269            -(-number_of_requests // max_batch_size),
270            mock_request.call_count)
271
272    def testRefreshOnAuthFailure(self):
273        mock_service = FakeService()
274
275        desired_url = 'https://www.example.com'
276        batch_api_request = batch.BatchApiRequest(batch_url=desired_url)
277        # The request to be added. The actual request sent will be somewhat
278        # larger, as this is added to a batch.
279        desired_request = http_wrapper.Request(desired_url, 'POST', {
280            'content-type': 'multipart/mixed; boundary="None"',
281            'content-length': 80,
282        }, 'x' * 80)
283
284        with mock.patch.object(http_wrapper, 'MakeRequest',
285                               autospec=True) as mock_request:
286            self.__ConfigureMock(
287                mock_request,
288                http_wrapper.Request(desired_url, 'POST', {
289                    'content-type': 'multipart/mixed; boundary="None"',
290                    'content-length': 419,
291                }, 'x' * 419), [
292                    http_wrapper.Response({
293                        'status': '200',
294                        'content-type': 'multipart/mixed; boundary="boundary"',
295                    }, textwrap.dedent("""\
296                    --boundary
297                    content-type: text/plain
298                    content-id: <id+0>
299
300                    HTTP/1.1 401 UNAUTHORIZED
301                    Invalid grant
302
303                    --boundary--"""), None),
304                    http_wrapper.Response({
305                        'status': '200',
306                        'content-type': 'multipart/mixed; boundary="boundary"',
307                    }, textwrap.dedent("""\
308                    --boundary
309                    content-type: text/plain
310                    content-id: <id+0>
311
312                    HTTP/1.1 200 OK
313                    content
314                    --boundary--"""), None)
315                ])
316
317            batch_api_request.Add(mock_service, 'unused', None, {
318                'desired_request': desired_request,
319            })
320
321            credentials = FakeCredentials()
322            api_request_responses = batch_api_request.Execute(
323                FakeHttp(credentials=credentials), sleep_between_polls=0)
324
325            self.assertEqual(1, len(api_request_responses))
326            self.assertEqual(2, mock_request.call_count)
327            self.assertEqual(1, credentials.num_refreshes)
328
329            self.assertFalse(api_request_responses[0].is_error)
330
331            response = api_request_responses[0].response
332            self.assertEqual({'status': '200'}, response.info)
333            self.assertEqual('content', response.content)
334            self.assertEqual(desired_url, response.request_url)
335
336    def testNoAttempts(self):
337        desired_url = 'https://www.example.com'
338        batch_api_request = batch.BatchApiRequest(batch_url=desired_url)
339        batch_api_request.Add(FakeService(), 'unused', None, {
340            'desired_request': http_wrapper.Request(desired_url, 'POST', {
341                'content-type': 'multipart/mixed; boundary="None"',
342                'content-length': 80,
343            }, 'x' * 80),
344        })
345        api_request_responses = batch_api_request.Execute(None, max_retries=0)
346        self.assertEqual(1, len(api_request_responses))
347        self.assertIsNone(api_request_responses[0].response)
348        self.assertIsNone(api_request_responses[0].exception)
349
350    def _DoTestConvertIdToHeader(self, test_id, expected_result):
351        batch_request = batch.BatchHttpRequest('https://www.example.com')
352        self.assertEqual(
353            expected_result % batch_request._BatchHttpRequest__base_id,
354            batch_request._ConvertIdToHeader(test_id))
355
356    def testConvertIdSimple(self):
357        self._DoTestConvertIdToHeader('blah', '<%s+blah>')
358
359    def testConvertIdThatNeedsEscaping(self):
360        self._DoTestConvertIdToHeader('~tilde1', '<%s+%%7Etilde1>')
361
362    def _DoTestConvertHeaderToId(self, header, expected_id):
363        batch_request = batch.BatchHttpRequest('https://www.example.com')
364        self.assertEqual(expected_id,
365                         batch_request._ConvertHeaderToId(header))
366
367    def testConvertHeaderToIdSimple(self):
368        self._DoTestConvertHeaderToId('<hello+blah>', 'blah')
369
370    def testConvertHeaderToIdWithLotsOfPlus(self):
371        self._DoTestConvertHeaderToId('<a+++++plus>', 'plus')
372
373    def _DoTestConvertInvalidHeaderToId(self, invalid_header):
374        batch_request = batch.BatchHttpRequest('https://www.example.com')
375        self.assertRaises(exceptions.BatchError,
376                          batch_request._ConvertHeaderToId, invalid_header)
377
378    def testHeaderWithoutAngleBrackets(self):
379        self._DoTestConvertInvalidHeaderToId('1+1')
380
381    def testHeaderWithoutPlus(self):
382        self._DoTestConvertInvalidHeaderToId('<HEADER>')
383
384    def testSerializeRequest(self):
385        request = http_wrapper.Request(body='Hello World', headers={
386            'content-type': 'protocol/version',
387        })
388        expected_serialized_request = '\n'.join([
389            'GET  HTTP/1.1',
390            'Content-Type: protocol/version',
391            'MIME-Version: 1.0',
392            'content-length: 11',
393            'Host: ',
394            '',
395            'Hello World',
396        ])
397        batch_request = batch.BatchHttpRequest('https://www.example.com')
398        self.assertEqual(expected_serialized_request,
399                         batch_request._SerializeRequest(request))
400
401    def testSerializeRequestPreservesHeaders(self):
402        # Now confirm that if an additional, arbitrary header is added
403        # that it is successfully serialized to the request. Merely
404        # check that it is included, because the order of the headers
405        # in the request is arbitrary.
406        request = http_wrapper.Request(body='Hello World', headers={
407            'content-type': 'protocol/version',
408            'key': 'value',
409        })
410        batch_request = batch.BatchHttpRequest('https://www.example.com')
411        self.assertTrue(
412            'key: value\n' in batch_request._SerializeRequest(request))
413
414    def testSerializeRequestNoBody(self):
415        request = http_wrapper.Request(body=None, headers={
416            'content-type': 'protocol/version',
417        })
418        expected_serialized_request = '\n'.join([
419            'GET  HTTP/1.1',
420            'Content-Type: protocol/version',
421            'MIME-Version: 1.0',
422            'Host: ',
423            '',
424            '',
425        ])
426        batch_request = batch.BatchHttpRequest('https://www.example.com')
427        self.assertEqual(expected_serialized_request,
428                         batch_request._SerializeRequest(request))
429
430    def testDeserializeRequest(self):
431        serialized_payload = '\n'.join([
432            'GET  HTTP/1.1',
433            'Content-Type: protocol/version',
434            'MIME-Version: 1.0',
435            'content-length: 11',
436            'key: value',
437            'Host: ',
438            '',
439            'Hello World',
440        ])
441        example_url = 'https://www.example.com'
442        expected_response = http_wrapper.Response({
443            'content-length': str(len('Hello World')),
444            'Content-Type': 'protocol/version',
445            'key': 'value',
446            'MIME-Version': '1.0',
447            'status': '',
448            'Host': ''
449        }, 'Hello World', example_url)
450
451        batch_request = batch.BatchHttpRequest(example_url)
452        self.assertEqual(
453            expected_response,
454            batch_request._DeserializeResponse(serialized_payload))
455
456    def testNewId(self):
457        batch_request = batch.BatchHttpRequest('https://www.example.com')
458
459        for i in range(100):
460            self.assertEqual(str(i), batch_request._NewId())
461
462    def testAdd(self):
463        batch_request = batch.BatchHttpRequest('https://www.example.com')
464
465        for x in range(100):
466            batch_request.Add(http_wrapper.Request(body=str(x)))
467
468        for key in batch_request._BatchHttpRequest__request_response_handlers:
469            value = batch_request._BatchHttpRequest__request_response_handlers[
470                key]
471            self.assertEqual(key, value.request.body)
472            self.assertFalse(value.request.url)
473            self.assertEqual('GET', value.request.http_method)
474            self.assertIsNone(value.response)
475            self.assertIsNone(value.handler)
476
477    def testInternalExecuteWithFailedRequest(self):
478        with mock.patch.object(http_wrapper, 'MakeRequest',
479                               autospec=True) as mock_request:
480            self.__ConfigureMock(
481                mock_request,
482                http_wrapper.Request('https://www.example.com', 'POST', {
483                    'content-type': 'multipart/mixed; boundary="None"',
484                    'content-length': 80,
485                }, 'x' * 80),
486                http_wrapper.Response({'status': '300'}, None, None))
487
488            batch_request = batch.BatchHttpRequest('https://www.example.com')
489
490            self.assertRaises(
491                exceptions.HttpError, batch_request._Execute, None)
492
493    def testInternalExecuteWithNonMultipartResponse(self):
494        with mock.patch.object(http_wrapper, 'MakeRequest',
495                               autospec=True) as mock_request:
496            self.__ConfigureMock(
497                mock_request,
498                http_wrapper.Request('https://www.example.com', 'POST', {
499                    'content-type': 'multipart/mixed; boundary="None"',
500                    'content-length': 80,
501                }, 'x' * 80),
502                http_wrapper.Response({
503                    'status': '200',
504                    'content-type': 'blah/blah'
505                }, '', None))
506
507            batch_request = batch.BatchHttpRequest('https://www.example.com')
508
509            self.assertRaises(
510                exceptions.BatchError, batch_request._Execute, None)
511
512    def testInternalExecute(self):
513        with mock.patch.object(http_wrapper, 'MakeRequest',
514                               autospec=True) as mock_request:
515            self.__ConfigureMock(
516                mock_request,
517                http_wrapper.Request('https://www.example.com', 'POST', {
518                    'content-type': 'multipart/mixed; boundary="None"',
519                    'content-length': 583,
520                }, 'x' * 583),
521                http_wrapper.Response({
522                    'status': '200',
523                    'content-type': 'multipart/mixed; boundary="boundary"',
524                }, textwrap.dedent("""\
525                --boundary
526                content-type: text/plain
527                content-id: <id+2>
528
529                HTTP/1.1 200 OK
530                Second response
531
532                --boundary
533                content-type: text/plain
534                content-id: <id+1>
535
536                HTTP/1.1 401 UNAUTHORIZED
537                First response
538
539                --boundary--"""), None))
540
541            test_requests = {
542                '1': batch.RequestResponseAndHandler(
543                    http_wrapper.Request(body='first'), None, None),
544                '2': batch.RequestResponseAndHandler(
545                    http_wrapper.Request(body='second'), None, None),
546            }
547
548            batch_request = batch.BatchHttpRequest('https://www.example.com')
549            batch_request._BatchHttpRequest__request_response_handlers = (
550                test_requests)
551
552            batch_request._Execute(FakeHttp())
553
554            test_responses = (
555                batch_request._BatchHttpRequest__request_response_handlers)
556
557            self.assertEqual(http_client.UNAUTHORIZED,
558                             test_responses['1'].response.status_code)
559            self.assertEqual(http_client.OK,
560                             test_responses['2'].response.status_code)
561
562            self.assertIn(
563                'First response', test_responses['1'].response.content)
564            self.assertIn(
565                'Second response', test_responses['2'].response.content)
566
567    def testPublicExecute(self):
568
569        def LocalCallback(response, exception):
570            self.assertEqual({'status': '418'}, response.info)
571            self.assertEqual('Teapot', response.content)
572            self.assertIsNone(response.request_url)
573            self.assertIsInstance(exception, exceptions.HttpError)
574
575        global_callback = mock.Mock()
576        batch_request = batch.BatchHttpRequest(
577            'https://www.example.com', global_callback)
578
579        with mock.patch.object(batch.BatchHttpRequest, '_Execute',
580                               autospec=True) as mock_execute:
581            mock_execute.return_value = None
582
583            test_requests = {
584                '0': batch.RequestResponseAndHandler(
585                    None,
586                    http_wrapper.Response({'status': '200'}, 'Hello!', None),
587                    None),
588                '1': batch.RequestResponseAndHandler(
589                    None,
590                    http_wrapper.Response({'status': '418'}, 'Teapot', None),
591                    LocalCallback),
592            }
593
594            batch_request._BatchHttpRequest__request_response_handlers = (
595                test_requests)
596            batch_request.Execute(None)
597
598            # Global callback was called once per handler.
599            self.assertEqual(len(test_requests), global_callback.call_count)
600