1# Copyright 2015 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import collections
16import logging
17import threading
18
19import grpc
20from grpc import _common
21from grpc._cython import cygrpc
22
# Module-level logger, named after this module so records can be filtered.
_LOGGER = logging.getLogger(__name__)
# Import-time side effect: installs a default handler on the root logger
# when the root logger has none configured yet.
# NOTE(review): configuring the root logger from library code can override
# an application's own logging choices -- confirm this is still intended.
logging.basicConfig()
25
26
class _AuthMetadataContext(
        collections.namedtuple('AuthMetadataContext',
                               ['service_url', 'method_name']),
        grpc.AuthMetadataContext):
    """Concrete grpc.AuthMetadataContext backed by a named tuple.

    Fields, in positional order:
      service_url: first positional value supplied at construction.
      method_name: second positional value supplied at construction.
    """
33
34
class _CallbackState(object):
    """Mutable bookkeeping shared between a plugin call and its callback.

    Attributes:
      lock: a threading.Lock intended to guard the two fields below.
      called: starts False; flipped once the callback has been invoked.
      exception: starts None; holds the exception raised by the plugin,
        if any.
    """

    def __init__(self):
        # Initial state: no exception recorded and no callback invocation.
        self.exception = None
        self.called = False
        self.lock = threading.Lock()
41
42
class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
    """One-shot callback handed to a metadata plugin.

    Forwards the plugin's result to the wrapped low-level callback and uses
    the shared state object to reject repeated or too-late invocations.
    """

    def __init__(self, state, callback):
        """Stores the shared state and the callback to forward results to.

        Args:
          state: object exposing `lock`, `called` and `exception`.
          callback: callable taking (metadata, status code, error details).
        """
        self._state = state
        self._callback = callback

    def __call__(self, metadata, error):
        """Reports the plugin's outcome through the wrapped callback.

        Args:
          metadata: the metadata to forward when `error` is None.
          error: None on success; otherwise its str() is encoded and
            forwarded as the error details.

        Raises:
          RuntimeError: if the plugin already raised an exception, or if
            this callback has already been invoked.
        """
        state = self._state
        # Validate and mark the invocation atomically; the wrapped callback
        # itself is deliberately invoked only after the lock is released.
        with state.lock:
            if state.exception is not None:
                raise RuntimeError(
                    'AuthMetadataPluginCallback raised exception "{}"!'.format(
                        state.exception))
            if state.called:
                raise RuntimeError(
                    'AuthMetadataPluginCallback invoked more than once!')
            state.called = True

        if error is not None:
            self._callback(None, cygrpc.StatusCode.internal,
                           _common.encode(str(error)))
            return
        self._callback(metadata, cygrpc.StatusCode.ok, None)
66
67
class _Plugin(object):
    """Callable adapter around a user-supplied metadata plugin.

    Accepts (service_url, method_name, callback), builds the context and
    callback objects the metadata plugin expects, and reports an internal
    error through `callback` if the plugin raises before calling back.
    """

    def __init__(self, metadata_plugin):
        """Remembers the metadata plugin to delegate to."""
        self._metadata_plugin = metadata_plugin

    def __call__(self, service_url, method_name, callback):
        """Runs the metadata plugin for a single call.

        Args:
          service_url: value decoded with _common.decode for the context.
          method_name: value decoded with _common.decode for the context.
          callback: callable taking (metadata, status code, error details).
        """
        state = _CallbackState()
        context = _AuthMetadataContext(
            _common.decode(service_url), _common.decode(method_name))
        try:
            plugin_callback = _AuthMetadataPluginCallback(state, callback)
            self._metadata_plugin(context, plugin_callback)
        except Exception as error:  # pylint: disable=broad-except
            _LOGGER.exception(
                'AuthMetadataPluginCallback "%s" raised exception!',
                self._metadata_plugin)
            # Record the failure under the lock so that any later callback
            # invocation is rejected, and learn whether a result was already
            # delivered before the plugin raised.
            with state.lock:
                state.exception = error
                already_reported = state.called
            # Only report the failure if nothing has been delivered yet.
            if not already_reported:
                callback(None, cygrpc.StatusCode.internal,
                         _common.encode(str(error)))
91
92
def metadata_plugin_call_credentials(metadata_plugin, name):
    """Wraps a metadata plugin in a grpc.CallCredentials object.

    Args:
      metadata_plugin: the plugin callable to wrap.
      name: name for the credentials; when None, the plugin's `__name__` is
        used, falling back to the name of the plugin's class.

    Returns:
      A grpc.CallCredentials built around a
      cygrpc.MetadataPluginCallCredentials.
    """
    effective_name = name
    if effective_name is None:
        try:
            effective_name = metadata_plugin.__name__
        except AttributeError:
            # Objects without a __name__ are labelled by their class.
            effective_name = metadata_plugin.__class__.__name__
    plugin = _Plugin(metadata_plugin)
    encoded_name = _common.encode(effective_name)
    plugin_credentials = cygrpc.MetadataPluginCallCredentials(
        plugin, encoded_name)
    return grpc.CallCredentials(plugin_credentials)
104