1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests of standard AuthMetadataPlugins."""
15
16import collections
17import threading
18import unittest
19
20from grpc import _auth
21
22
class MockGoogleCreds(object):
    """Stand-in for Google credentials whose token lookup always succeeds."""

    # Record type for the returned token. Built once here instead of on every
    # call; the type name and field names are the ones the mock always used.
    _AccessTokenInfo = collections.namedtuple('MockAccessTokenInfo',
                                              ('access_token', 'expires_in'))

    def get_access_token(self):
        """Return a token record whose ``access_token`` is ``'token'``.

        Returns:
          A ``MockAccessTokenInfo`` namedtuple instance. Previously the
          namedtuple class itself was returned with its ``access_token``
          field overwritten by a string; an actual instance keeps the same
          attribute value while behaving like an ordinary record.
          ``expires_in`` is ``None`` since no test in this module inspects it.
        """
        return self._AccessTokenInfo(access_token='token', expires_in=None)
30
31
class MockExceptionGoogleCreds(object):
    """Stand-in for Google credentials whose token lookup always fails."""

    def get_access_token(self):
        """Unconditionally raise a plain ``Exception`` with no arguments."""
        failure = Exception()
        raise failure
36
37
class GoogleCallCredentialsTest(unittest.TestCase):
    """Checks what _auth.GoogleCallCredentials reports to its callback."""

    def _call_plugin(self, credentials):
        """Invoke the plugin with `credentials` and return what it reported.

        The callback only records its arguments and signals the event; the
        assertions are left to the caller so they run on the test thread.
        Asserting inside the callback is fragile: if the plugin invokes the
        callback on another thread (which the Event-based wait allows for),
        a failed assertion there would not reach the test runner and would
        surface only as an unexplained timeout.

        Args:
          credentials: Object exposing ``get_access_token()``.

        Returns:
          The ``(metadata, error)`` pair passed to the callback.
        """
        callback_event = threading.Event()
        received = []

        def mock_callback(metadata, error):
            received.append((metadata, error))
            callback_event.set()

        call_creds = _auth.GoogleCallCredentials(credentials)
        call_creds(None, mock_callback)
        self.assertTrue(
            callback_event.wait(1.0),
            'callback was not invoked within 1 second')
        return received[0]

    def test_google_call_credentials_success(self):
        """A working token lookup yields a Bearer header and no error."""
        metadata, error = self._call_plugin(MockGoogleCreds())
        self.assertEqual(metadata, (('authorization', 'Bearer token'),))
        self.assertIsNone(error)

    def test_google_call_credentials_error(self):
        """A raising token lookup is reported through the error argument."""
        _, error = self._call_plugin(MockExceptionGoogleCreds())
        self.assertIsNotNone(error)
62
63
class AccessTokenAuthMetadataPluginTest(unittest.TestCase):
    """Checks the metadata produced by _auth.AccessTokenAuthMetadataPlugin."""

    # NOTE(review): the method name matches the one used in
    # GoogleCallCredentialsTest even though this test exercises
    # AccessTokenAuthMetadataPlugin; it is left unchanged so the test ID
    # stays stable.
    def test_google_call_credentials_success(self):
        """The plugin reports a Bearer header for its token and no error."""
        callback_event = threading.Event()
        received = []

        def mock_callback(metadata, error):
            # Only record the arguments here; asserting on the test thread
            # (below) guarantees a failure reaches the runner with its real
            # message regardless of which thread invokes the callback.
            received.append((metadata, error))
            callback_event.set()

        metadata_plugin = _auth.AccessTokenAuthMetadataPlugin('token')
        metadata_plugin(None, mock_callback)
        self.assertTrue(
            callback_event.wait(1.0),
            'callback was not invoked within 1 second')

        metadata, error = received[0]
        self.assertEqual(metadata, (('authorization', 'Bearer token'),))
        self.assertIsNone(error)
77
78
# Script entry point: run every TestCase in this module with per-test output.
if __name__ == '__main__':
    unittest.main(verbosity=2)
81