1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Reference implementation for reflection in gRPC Python."""
15
16import grpc
17from google.protobuf import descriptor_pb2
18from google.protobuf import descriptor_pool
19
20from grpc_reflection.v1alpha import reflection_pb2 as _reflection_pb2
21from grpc_reflection.v1alpha import reflection_pb2_grpc as _reflection_pb2_grpc
22
# Descriptor pool consulted by ReflectionServicer when no pool is supplied.
_POOL = descriptor_pool.Default()
# Fully-qualified name of the ServerReflection service, taken from the
# generated reflection_pb2 file descriptor.
SERVICE_NAME = (
    _reflection_pb2.DESCRIPTOR.services_by_name['ServerReflection'].full_name)
26
27
def _not_found_error():
    """Builds a reflection response carrying a NOT_FOUND error.

    Returns:
      A ServerReflectionResponse whose error_response holds the numeric code
      and encoded description of grpc.StatusCode.NOT_FOUND.
    """
    code, description = grpc.StatusCode.NOT_FOUND.value
    error = _reflection_pb2.ErrorResponse(
        error_code=code,
        error_message=description.encode(),
    )
    return _reflection_pb2.ServerReflectionResponse(error_response=error)
34
35
def _file_descriptor_response(descriptor):
    """Wraps a single file descriptor in a reflection response.

    Args:
      descriptor: A protobuf FileDescriptor to serialize.

    Returns:
      A ServerReflectionResponse whose file_descriptor_response contains the
      descriptor serialized as one FileDescriptorProto.
    """
    file_proto = descriptor_pb2.FileDescriptorProto()
    descriptor.CopyToProto(file_proto)
    descriptor_response = _reflection_pb2.FileDescriptorResponse(
        file_descriptor_proto=[file_proto.SerializeToString()])
    return _reflection_pb2.ServerReflectionResponse(
        file_descriptor_response=descriptor_response)
43
44
class ReflectionServicer(_reflection_pb2_grpc.ServerReflectionServicer):
    """Servicer implementing the gRPC ServerReflection service.

    Answers reflection queries by looking up descriptors in a protobuf
    DescriptorPool and by listing the service names given at construction.
    """

    def __init__(self, service_names, pool=None):
        """Constructor.

        Args:
          service_names: Iterable of fully-qualified service names available.
          pool: DescriptorPool object to use (descriptor_pool.Default() if
            None).
        """
        # Sorted once so list_services responses have a stable order.
        self._service_names = tuple(sorted(service_names))
        self._pool = _POOL if pool is None else pool

    def _file_by_filename(self, filename):
        """Returns the file descriptor response for `filename`, or NOT_FOUND."""
        try:
            descriptor = self._pool.FindFileByName(filename)
        except KeyError:
            return _not_found_error()
        else:
            return _file_descriptor_response(descriptor)

    def _file_containing_symbol(self, fully_qualified_name):
        """Returns the file that declares `fully_qualified_name`, or NOT_FOUND."""
        try:
            descriptor = self._pool.FindFileContainingSymbol(
                fully_qualified_name)
        except KeyError:
            return _not_found_error()
        else:
            return _file_descriptor_response(descriptor)

    def _file_containing_extension(self, containing_type, extension_number):
        """Returns the file declaring an extension of `containing_type`.

        Args:
          containing_type: Fully-qualified name of the extended message type.
          extension_number: Field number of the extension.

        Returns:
          A file descriptor response, or a NOT_FOUND error response if any
          lookup fails.
        """
        try:
            message_descriptor = self._pool.FindMessageTypeByName(
                containing_type)
            extension_descriptor = self._pool.FindExtensionByNumber(
                message_descriptor, extension_number)
            descriptor = self._pool.FindFileContainingSymbol(
                extension_descriptor.full_name)
        except KeyError:
            return _not_found_error()
        else:
            return _file_descriptor_response(descriptor)

    def _all_extension_numbers_of_type(self, containing_type):
        """Returns the sorted extension numbers known for `containing_type`.

        Returns a NOT_FOUND error response if the message type is unknown.
        """
        try:
            message_descriptor = self._pool.FindMessageTypeByName(
                containing_type)
            extension_numbers = tuple(
                sorted(
                    extension.number
                    for extension in self._pool.FindAllExtensions(
                        message_descriptor)))
        except KeyError:
            return _not_found_error()
        else:
            return _reflection_pb2.ServerReflectionResponse(
                all_extension_numbers_response=_reflection_pb2.
                ExtensionNumberResponse(
                    base_type_name=message_descriptor.full_name,
                    extension_number=extension_numbers))

    def _list_services(self):
        """Returns a response listing the configured service names."""
        return _reflection_pb2.ServerReflectionResponse(
            list_services_response=_reflection_pb2.ListServiceResponse(
                service=[
                    _reflection_pb2.ServiceResponse(name=service_name)
                    for service_name in self._service_names
                ]))

    def _dispatch(self, request):
        """Computes a fresh response for a single reflection request.

        Args:
          request: A ServerReflectionRequest.

        Returns:
          A newly built ServerReflectionResponse. An INVALID_ARGUMENT error
          response is returned when none of the known request fields is set.
        """
        if request.HasField('file_by_filename'):
            return self._file_by_filename(request.file_by_filename)
        if request.HasField('file_containing_symbol'):
            return self._file_containing_symbol(request.file_containing_symbol)
        if request.HasField('file_containing_extension'):
            return self._file_containing_extension(
                request.file_containing_extension.containing_type,
                request.file_containing_extension.extension_number)
        if request.HasField('all_extension_numbers_of_type'):
            return self._all_extension_numbers_of_type(
                request.all_extension_numbers_of_type)
        if request.HasField('list_services'):
            return self._list_services()
        return _reflection_pb2.ServerReflectionResponse(
            error_response=_reflection_pb2.ErrorResponse(
                error_code=grpc.StatusCode.INVALID_ARGUMENT.value[0],
                error_message=grpc.StatusCode.INVALID_ARGUMENT.value[1]
                .encode(),
            ))

    def ServerReflectionInfo(self, request_iterator, context):
        """Streams one reflection response per incoming request.

        Each response echoes the request that produced it in its
        `original_request` field, which the reflection protocol defines for
        that purpose; previously this field was always left empty.

        Args:
          request_iterator: Iterator of ServerReflectionRequest messages.
          context: The RPC context (unused).

        Yields:
          ServerReflectionResponse messages, in request order.
        """
        # pylint: disable=unused-argument
        for request in request_iterator:
            # _dispatch always builds a new message, so mutating it is safe.
            response = self._dispatch(request)
            response.original_request.CopyFrom(request)
            yield response
137
138
def enable_server_reflection(service_names, server, pool=None):
    """Registers the reflection service on a gRPC server.

    Args:
      service_names: Iterable of fully-qualified service names available.
      server: grpc.Server to which reflection service will be added.
      pool: DescriptorPool object to use (descriptor_pool.Default() if None).
    """
    servicer = ReflectionServicer(service_names, pool=pool)
    _reflection_pb2_grpc.add_ServerReflectionServicer_to_server(
        servicer, server)
149