1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""GRPCAuthMetadataPlugins for standard authentication."""
15
16import inspect
17from concurrent import futures
18
19import grpc
20
21
def _sign_request(callback, token, error):
    """Hand a bearer ``authorization`` header to a gRPC metadata callback.

    Args:
      callback: Invoked exactly once as ``callback(metadata, error)``.
      token: Access token interpolated into the header value. Callers pass
        None when reporting an error, in which case the header reads
        ``Bearer None``.
      error: Exception to report to ``callback``, or None on success.
    """
    header_value = 'Bearer {}'.format(token)
    auth_metadata = (('authorization', header_value),)
    callback(auth_metadata, error)
25
26
def _create_get_token_callback(callback):
    """Build a done-callback that forwards a token future's outcome to gRPC.

    Args:
      callback: The gRPC metadata callback, ultimately invoked as
        ``callback(metadata, error)`` through ``_sign_request``.

    Returns:
      A one-argument function meant for ``Future.add_done_callback``. It signs
      with ``future.result().access_token`` on success. If fetching the result
      or reading ``access_token`` raises, the exception is reported instead,
      with a None token.
    """

    def _on_token_ready(token_future):
        try:
            token = token_future.result().access_token
        except Exception as err:  # pylint: disable=broad-except
            # Any failure obtaining the token is delivered to gRPC as the
            # callback's error, rather than propagating out of this callback.
            _sign_request(callback, None, err)
            return
        _sign_request(callback, token, None)

    return _on_token_ready
38
39
def _accepts_additional_claims(get_access_token):
    """Return True if ``get_access_token`` takes an ``additional_claims`` argument.

    ``inspect.getargspec`` was removed in Python 3.11. Before that, on
    Python 3 it raised ValueError for callables with annotations or
    keyword-only parameters. ``inspect.signature`` is therefore preferred: it
    also reports keyword-only parameters, which is sufficient here because the
    argument is always passed by keyword. ``getargspec`` is only used on
    interpreters that lack ``inspect.signature``.
    """
    signature = getattr(inspect, 'signature', None)
    if signature is None:
        return 'additional_claims' in inspect.getargspec(get_access_token).args
    return 'additional_claims' in signature(get_access_token).parameters


class GoogleCallCredentials(grpc.AuthMetadataPlugin):
    """Metadata wrapper for GoogleCredentials from the oauth2client library.

    Tokens are fetched on a private single-worker thread pool, so ``__call__``
    returns without waiting on ``get_access_token``.
    """

    def __init__(self, credentials):
        """Wrap ``credentials`` for use as a gRPC auth metadata plugin.

        Args:
          credentials: Object exposing ``get_access_token()``, whose return
            value has an ``access_token`` attribute (presumably an oauth2client
            credentials object -- confirm against callers).
        """
        self._credentials = credentials
        self._pool = futures.ThreadPoolExecutor(max_workers=1)

        # Hack to determine if these are JWT creds and we need to pass
        # additional_claims when getting a token.
        self._is_jwt = _accepts_additional_claims(credentials.get_access_token)

    def __call__(self, context, callback):
        """Fetch a token asynchronously and deliver it to ``callback``.

        Args:
          context: gRPC auth metadata context; ``context.service_url`` is used
            as the JWT ``aud`` claim when the credentials accept claims.
          callback: Invoked later, from the pool's worker, as
            ``callback(metadata, error)``.
        """
        # MetadataPlugins cannot block (see grpc.beta.interfaces.py), so the
        # token is fetched on the pool and delivered via a done-callback.
        if self._is_jwt:
            future = self._pool.submit(
                self._credentials.get_access_token,
                additional_claims={'aud': context.service_url})
        else:
            future = self._pool.submit(self._credentials.get_access_token)
        future.add_done_callback(_create_get_token_callback(callback))

    def __del__(self):
        # If __init__ raised before _pool was assigned (e.g. bad arguments),
        # skip shutdown instead of raising AttributeError during collection.
        pool = getattr(self, '_pool', None)
        if pool is not None:
            pool.shutdown(wait=False)
66
67
class AccessTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin):
    """Metadata wrapper for raw access token credentials.

    The token given at construction is sent unchanged as a bearer
    ``authorization`` header on every call; no refresh is performed.
    """

    def __init__(self, access_token):
        """Store ``access_token`` for use on every subsequent call."""
        self._access_token = access_token

    def __call__(self, context, callback):
        """Sign synchronously with the stored token; ``context`` is unused."""
        token = self._access_token
        _sign_request(callback, token, error=None)
76