1# Copyright 2015 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Tests for oauth2client.contrib.devshell."""
16
17import datetime
18import json
19import os
20import socket
21import threading
22
23import mock
24import unittest2
25
26from oauth2client import _helpers
27from oauth2client import client
28from oauth2client.contrib import devshell
29
# Placeholder lifetime, in seconds, reported in the expires_in slot of a
# CredentialInfoResponse.
EXPIRES_IN = 1000
# Payload that _AuthReferenceServer replies with when no explicit response
# is supplied: a JSON list ordered as
# [user_email, project_id, access_token, expires_in].
_DEFAULT_CREDENTIAL_FIELDS = ['joe@example.com', 'fooproj', 'sometoken']
DEFAULT_CREDENTIAL_JSON = json.dumps(_DEFAULT_CREDENTIAL_FIELDS + [EXPIRES_IN])
39
40
class TestCredentialInfoResponse(unittest2.TestCase):
    """Tests for parsing a devshell.CredentialInfoResponse from JSON text."""

    def test_constructor_with_non_list(self):
        # Well-formed JSON that is not a list must be rejected.
        json_non_list = '{}'
        with self.assertRaises(ValueError):
            devshell.CredentialInfoResponse(json_non_list)

    def test_constructor_with_bad_json(self):
        # Malformed JSON must be rejected.
        bad_json = '{BADJSON'
        with self.assertRaises(ValueError):
            devshell.CredentialInfoResponse(bad_json)

    def test_constructor_empty_list(self):
        # An empty list is accepted; every field is left as None.
        info_response = devshell.CredentialInfoResponse('[]')
        self.assertIsNone(info_response.user_email)
        self.assertIsNone(info_response.project_id)
        self.assertIsNone(info_response.access_token)
        self.assertIsNone(info_response.expires_in)

    def test_constructor_full_list(self):
        # Fields are read positionally:
        # [user_email, project_id, access_token, expires_in].
        user_email = 'user_email'
        project_id = 'project_id'
        access_token = 'access_token'
        expires_in = 1
        json_string = json.dumps(
            [user_email, project_id, access_token, expires_in])
        info_response = devshell.CredentialInfoResponse(json_string)
        self.assertEqual(info_response.user_email, user_email)
        self.assertEqual(info_response.project_id, project_id)
        self.assertEqual(info_response.access_token, access_token)
        self.assertEqual(info_response.expires_in, expires_in)
72
73
class Test_SendRecv(unittest2.TestCase):
    """Tests for devshell._SendRecv using mocked os/socket modules."""

    def test_port_zero(self):
        # A port of 0 from the environment means no server is available.
        with mock.patch('oauth2client.contrib.devshell.os') as os_mod:
            os_mod.getenv = mock.MagicMock(name='getenv', return_value=0)
            with self.assertRaises(devshell.NoDevshellServer):
                devshell._SendRecv()
            os_mod.getenv.assert_called_once_with(devshell.DEVSHELL_ENV, 0)

    def test_no_newline_in_received_header(self):
        non_zero_port = 1
        sock = mock.MagicMock()

        # Configure what the header read returns. Note that evaluating
        # sock.recv(6) here is itself recorded in sock.method_calls.
        header_without_newline = ''
        sock.recv(6).decode = mock.MagicMock(
            name='decode', return_value=header_without_newline)

        with mock.patch('oauth2client.contrib.devshell.os') as os_mod:
            os_mod.getenv = mock.MagicMock(name='getenv',
                                           return_value=non_zero_port)
            # Bind the patched module under a name that does not shadow the
            # real ``socket`` module imported at file level.
            with mock.patch(
                    'oauth2client.contrib.devshell.socket') as socket_mod:
                socket_mod.socket = mock.MagicMock(name='socket',
                                                   return_value=sock)
                with self.assertRaises(devshell.CommunicationError):
                    devshell._SendRecv()
                os_mod.getenv.assert_called_once_with(devshell.DEVSHELL_ENV, 0)
                socket_mod.socket.assert_called_once_with()
                # This lookup records one more recv(6) call as well.
                sock.recv(6).decode.assert_called_once_with()

                data = devshell.CREDENTIAL_INFO_REQUEST_JSON
                msg = _helpers._to_bytes(
                    '{0}\n{1}'.format(len(data), data), encoding='utf-8')
                expected_sock_calls = [
                    mock.call.recv(6),  # From the set-up above
                    mock.call.connect(('localhost', non_zero_port)),
                    mock.call.sendall(msg),
                    mock.call.recv(6),
                    mock.call.recv(6),  # From the check above
                ]
                self.assertEqual(sock.method_calls, expected_sock_calls)
114
115
class _AuthReferenceServer(threading.Thread):
    """A one-shot fake Developer Shell credential server for tests.

    Used as a context manager (or via ``start_server``/``stop_server``).
    Starting it binds a listening socket on an ephemeral localhost port and
    publishes that port through the ``devshell.DEVSHELL_ENV`` environment
    variable. It then starts a daemon thread that serves exactly one
    connection.

    That connection reads one length-prefixed request (length, newline,
    payload). It sets ``bad_request`` to True if the payload differs from
    ``devshell.CREDENTIAL_INFO_REQUEST_JSON``. It then replies with
    ``self.response`` using the same framing.

    Args:
        response: str, JSON payload to reply with. Falls back to
            DEFAULT_CREDENTIAL_JSON when omitted or empty.
    """

    def __init__(self, response=None):
        super(_AuthReferenceServer, self).__init__(None)
        self.response = response or DEFAULT_CREDENTIAL_JSON
        # Set by the serving thread if the request payload was unexpected.
        self.bad_request = False

    def __enter__(self):
        return self.start_server()

    def start_server(self):
        """Bind, publish the port in the environment, and start serving.

        Returns:
            This server instance.
        """
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Port 0 lets the OS pick a free port; read it back afterwards.
        self._socket.bind(('localhost', 0))
        port = self._socket.getsockname()[1]
        os.environ[devshell.DEVSHELL_ENV] = str(port)
        self._socket.listen(0)
        # Don't let a stuck serving thread keep the test process alive.
        self.daemon = True
        self.start()
        return self

    def __exit__(self, e_type, value, traceback):
        self.stop_server()

    def stop_server(self):
        """Remove the published port and close the listening socket."""
        del os.environ[devshell.DEVSHELL_ENV]
        self._socket.close()

    def run(self):
        """Serve a single request on the listening socket."""
        conn = None
        try:
            # Do not set the timeout on the socket, leave it in the blocking
            # mode as setting the timeout seems to cause spurious EAGAIN
            # errors on OSX.
            self._socket.settimeout(None)

            conn, unused_addr = self._socket.accept()
            # The first 6 bytes hold the decimal length, a newline, and the
            # start of the payload.
            header = conn.recv(6).decode()
            nstr, resp_buffer = header.split('\n', 1)
            to_read = int(nstr) - len(resp_buffer)
            if to_read > 0:
                resp_buffer += _helpers._from_bytes(
                    conn.recv(to_read, socket.MSG_WAITALL))
            if resp_buffer != devshell.CREDENTIAL_INFO_REQUEST_JSON:
                self.bad_request = True
            length = len(self.response)
            conn.sendall('{0}\n{1}'.format(length, self.response).encode())
        finally:
            # accept() may raise (e.g. if the listening socket is closed
            # first). Then there is no connection to close, and an
            # AttributeError here would mask the original exception.
            if conn is not None:
                conn.close()
169
170
class DevshellCredentialsTests(unittest2.TestCase):
    """End-to-end tests of DevshellCredentials against _AuthReferenceServer."""

    def test_signals_no_server(self):
        # Hide any port inherited from the real environment (e.g. when the
        # suite runs inside a Developer Shell). patch.dict restores
        # os.environ when the block exits.
        with mock.patch.dict(os.environ):
            os.environ.pop(devshell.DEVSHELL_ENV, None)
            with self.assertRaises(devshell.NoDevshellServer):
                devshell.DevshellCredentials()

    def test_bad_message_to_mock_server(self):
        request_content = devshell.CREDENTIAL_INFO_REQUEST_JSON + 'extrastuff'
        request_message = _helpers._to_bytes(
            '{0}\n{1}'.format(len(request_content), request_content))
        response_message = 'foobar'
        with _AuthReferenceServer(response_message) as auth_server:
            self.assertFalse(auth_server.bad_request)
            sock = socket.socket()
            try:
                port = int(os.getenv(devshell.DEVSHELL_ENV, 0))
                sock.connect(('localhost', port))
                sock.sendall(request_message)

                # Mimic the receive part of _SendRecv
                header = sock.recv(6).decode()
                len_str, result = header.split('\n', 1)
                to_read = int(len_str) - len(result)
                if to_read > 0:
                    result += sock.recv(to_read, socket.MSG_WAITALL).decode()
            finally:
                # Always release the client socket, even if an assertion
                # or socket call above fails.
                sock.close()

        self.assertTrue(auth_server.bad_request)
        self.assertEqual(result, response_message)

    def test_request_response(self):
        with _AuthReferenceServer():
            response = devshell._SendRecv()
            self.assertEqual(response.user_email, 'joe@example.com')
            self.assertEqual(response.project_id, 'fooproj')
            self.assertEqual(response.access_token, 'sometoken')
            self.assertEqual(response.expires_in, EXPIRES_IN)

    def test_no_refresh_token(self):
        with _AuthReferenceServer():
            creds = devshell.DevshellCredentials()
            self.assertIsNone(creds.refresh_token)

    @mock.patch('oauth2client.client._UTCNOW')
    def test_reads_credentials(self, utcnow):
        NOW = datetime.datetime(1992, 12, 31)
        utcnow.return_value = NOW
        with _AuthReferenceServer():
            creds = devshell.DevshellCredentials()
            self.assertEqual('joe@example.com', creds.user_email)
            self.assertEqual('fooproj', creds.project_id)
            self.assertEqual('sometoken', creds.access_token)
            self.assertEqual(
                NOW + datetime.timedelta(seconds=EXPIRES_IN),
                creds.token_expiry)
            utcnow.assert_called_once_with()

    def test_handles_skipped_fields(self):
        with _AuthReferenceServer('["joe@example.com"]'):
            creds = devshell.DevshellCredentials()
            self.assertEqual('joe@example.com', creds.user_email)
            self.assertIsNone(creds.project_id)
            self.assertIsNone(creds.access_token)
            self.assertIsNone(creds.token_expiry)

    def test_handles_tiny_response(self):
        with _AuthReferenceServer('[]'):
            creds = devshell.DevshellCredentials()
            self.assertIsNone(creds.user_email)
            self.assertIsNone(creds.project_id)
            self.assertIsNone(creds.access_token)

    def test_handles_ignores_extra_fields(self):
        with _AuthReferenceServer(
                '["joe@example.com", "fooproj", "sometoken", 1, "extra"]'):
            creds = devshell.DevshellCredentials()
            self.assertEqual('joe@example.com', creds.user_email)
            self.assertEqual('fooproj', creds.project_id)
            self.assertEqual('sometoken', creds.access_token)

    def test_refuses_to_save_to_well_known_file(self):
        # Make os.path.isdir report True for every path, presumably so that
        # save_to_well_known_file does not stop early on a missing config
        # directory. mock.patch restores the real function afterwards.
        with mock.patch('os.path.isdir', return_value=True):
            with _AuthReferenceServer():
                creds = devshell.DevshellCredentials()
                with self.assertRaises(NotImplementedError):
                    client.save_to_well_known_file(creds)

    def test_from_json(self):
        with self.assertRaises(NotImplementedError):
            devshell.DevshellCredentials.from_json(None)

    def test_serialization_data(self):
        with _AuthReferenceServer('[]'):
            credentials = devshell.DevshellCredentials()
            with self.assertRaises(NotImplementedError):
                getattr(credentials, 'serialization_data')
267