1# Copyright 2016 Google Inc. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import socket
16import sys
17import threading
18
19import mock
20from six.moves.urllib import request
21import unittest2
22
23from oauth2client import client
24from oauth2client import tools
25
try:
    # argparse is needed to build the flag namespaces used by TestRunFlow;
    # skip the whole module if it cannot be imported.
    import argparse
except ImportError:  # pragma: NO COVER
    raise unittest2.SkipTest('argparse unavailable.')
30
31
class TestClientRedirectServer(unittest2.TestCase):
    """Test the ClientRedirectServer and ClientRedirectHandler classes."""

    def test_ClientRedirectServer(self):
        # Create a ClientRedirectServer on an ephemeral port (port 0) and run
        # a single handle_request() call in a background thread. A GET with an
        # authorization ``code`` query parameter is then sent to it; the test
        # checks that the server replies with a non-empty body and records
        # the parameter on ``httpd.query_params``.
        code = 'foo'
        httpd = tools.ClientRedirectServer(('localhost', 0),
                                           tools.ClientRedirectHandler)
        try:
            url = 'http://localhost:{0}?code={1}'.format(
                httpd.server_address[1], code)
            t = threading.Thread(target=httpd.handle_request)
            # ``Thread.setDaemon`` is deprecated (Python 3.10+); assigning the
            # attribute is equivalent and works on every supported version.
            # Daemon mode keeps a stuck server thread from blocking exit.
            t.daemon = True
            t.start()
            f = request.urlopen(url)
            try:
                self.assertTrue(f.read())
            finally:
                # Close the response explicitly so its socket is not leaked.
                f.close()
            t.join()
        finally:
            # Always release the listening socket, even when the request or
            # the assertion above fails.
            httpd.server_close()
        self.assertEqual(httpd.query_params.get('code'), code)
52
53
class TestRunFlow(unittest2.TestCase):
    """Tests for tools.run_flow.

    Covers both the local-webserver redirect path and the manual
    copy/paste (``noauth_local_webserver``) path, including error exits.
    """

    def setUp(self):
        # Mocks standing in for the redirect server, the OAuth flow, the
        # credential storage and the credentials produced by the exchange.
        self.server = mock.Mock()
        self.flow = mock.Mock()
        self.storage = mock.Mock()
        self.credentials = mock.Mock()

        self.flow.step1_get_authorize_url.return_value = (
            'http://example.com/auth')
        self.flow.step2_exchange.return_value = self.credentials

        # Flags selecting the manual (no local webserver) path.
        self.flags = argparse.Namespace(
            noauth_local_webserver=True, logging_level='INFO')
        # Flags selecting the local webserver path on localhost:8080.
        self.server_flags = argparse.Namespace(
            noauth_local_webserver=False,
            logging_level='INFO',
            auth_host_port=[8080],
            auth_host_name='localhost')

    @mock.patch.object(sys, 'argv', ['ignored', '--noauth_local_webserver'])
    @mock.patch('oauth2client.tools.logging')
    @mock.patch('oauth2client.tools.input')
    def test_run_flow_no_webserver(self, input_mock, logging_mock):
        input_mock.return_value = 'auth_code'

        # Successful exchange; flags come from the patched sys.argv.
        returned_credentials = tools.run_flow(self.flow, self.storage)

        self.assertEqual(self.credentials, returned_credentials)
        self.assertEqual(self.flow.redirect_uri, client.OOB_CALLBACK_URN)
        self.flow.step2_exchange.assert_called_once_with(
            'auth_code', http=None)
        self.storage.put.assert_called_once_with(self.credentials)
        self.credentials.set_store.assert_called_once_with(self.storage)

    @mock.patch('oauth2client.tools.logging')
    @mock.patch('oauth2client.tools.input')
    def test_run_flow_no_webserver_explicit_flags(
            self, input_mock, logging_mock):
        input_mock.return_value = 'auth_code'

        # Successful exchange with flags passed explicitly.
        returned_credentials = tools.run_flow(
            self.flow, self.storage, flags=self.flags)

        self.assertEqual(self.credentials, returned_credentials)
        self.assertEqual(self.flow.redirect_uri, client.OOB_CALLBACK_URN)
        self.flow.step2_exchange.assert_called_once_with(
            'auth_code', http=None)
        # Same persistence expectations as the other success-path tests.
        self.storage.put.assert_called_once_with(self.credentials)
        self.credentials.set_store.assert_called_once_with(self.storage)

    @mock.patch('oauth2client.tools.logging')
    @mock.patch('oauth2client.tools.input')
    def test_run_flow_no_webserver_exchange_error(
            self, input_mock, logging_mock):
        input_mock.return_value = 'auth_code'
        self.flow.step2_exchange.side_effect = client.FlowExchangeError()

        # Error while exchanging.
        with self.assertRaises(SystemExit):
            tools.run_flow(self.flow, self.storage, flags=self.flags)

        self.flow.step2_exchange.assert_called_once_with(
            'auth_code', http=None)
        # A failed exchange yields no credentials, so nothing is persisted.
        self.assertFalse(self.storage.put.called)
        self.assertFalse(self.credentials.set_store.called)

    @mock.patch('oauth2client.tools.logging')
    @mock.patch('oauth2client.tools.ClientRedirectServer')
    @mock.patch('webbrowser.open')
    def test_run_flow_webserver(
            self, webbrowser_open_mock, server_ctor_mock, logging_mock):
        server_ctor_mock.return_value = self.server
        self.server.query_params = {'code': 'auth_code'}

        # Successful exchange.
        returned_credentials = tools.run_flow(
            self.flow, self.storage, flags=self.server_flags)

        self.assertEqual(self.credentials, returned_credentials)
        self.assertEqual(self.flow.redirect_uri, 'http://localhost:8080/')
        self.flow.step2_exchange.assert_called_once_with(
            'auth_code', http=None)
        self.storage.put.assert_called_once_with(self.credentials)
        self.credentials.set_store.assert_called_once_with(self.storage)
        self.assertTrue(self.server.handle_request.called)
        webbrowser_open_mock.assert_called_once_with(
            'http://example.com/auth', autoraise=True, new=1)

    @mock.patch('oauth2client.tools.logging')
    @mock.patch('oauth2client.tools.ClientRedirectServer')
    @mock.patch('webbrowser.open')
    def test_run_flow_webserver_exchange_error(
            self, webbrowser_open_mock, server_ctor_mock, logging_mock):
        server_ctor_mock.return_value = self.server
        self.server.query_params = {'error': 'any error'}

        # The redirect carried an error instead of a code.
        with self.assertRaises(SystemExit):
            tools.run_flow(self.flow, self.storage, flags=self.server_flags)

        self.assertTrue(self.server.handle_request.called)
        # The run must abort before attempting an exchange or storing.
        self.assertFalse(self.flow.step2_exchange.called)
        self.assertFalse(self.storage.put.called)

    @mock.patch('oauth2client.tools.logging')
    @mock.patch('oauth2client.tools.ClientRedirectServer')
    @mock.patch('webbrowser.open')
    def test_run_flow_webserver_no_code(
            self, webbrowser_open_mock, server_ctor_mock, logging_mock):
        server_ctor_mock.return_value = self.server
        self.server.query_params = {}

        # No code found in the redirect's query parameters.
        with self.assertRaises(SystemExit):
            tools.run_flow(self.flow, self.storage, flags=self.server_flags)

        self.assertTrue(self.server.handle_request.called)
        # Without a code there is nothing to exchange or store.
        self.assertFalse(self.flow.step2_exchange.called)
        self.assertFalse(self.storage.put.called)

    @mock.patch('oauth2client.tools.logging')
    @mock.patch('oauth2client.tools.ClientRedirectServer')
    @mock.patch('oauth2client.tools.input')
    def test_run_flow_webserver_fallback(
            self, input_mock, server_ctor_mock, logging_mock):
        server_ctor_mock.side_effect = socket.error()
        input_mock.return_value = 'auth_code'

        # It should catch the socket error and proceed as if
        # noauth_local_webserver was specified.
        returned_credentials = tools.run_flow(
            self.flow, self.storage, flags=self.server_flags)

        self.assertEqual(self.credentials, returned_credentials)
        self.assertEqual(self.flow.redirect_uri, client.OOB_CALLBACK_URN)
        self.flow.step2_exchange.assert_called_once_with(
            'auth_code', http=None)
        self.storage.put.assert_called_once_with(self.credentials)
        self.credentials.set_store.assert_called_once_with(self.storage)
        self.assertTrue(server_ctor_mock.called)
        self.assertFalse(self.server.handle_request.called)
188
189
class TestMessageIfMissing(unittest2.TestCase):
    """Tests for tools.message_if_missing."""

    def test_message_if_missing(self):
        # The generated help text must mention the missing file by name.
        filename = 'somefile.txt'
        message = tools.message_if_missing(filename)
        self.assertIn(filename, message)
193