1 /*
2  *
3  * Copyright 2016 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
19 #include <unordered_set>
20 #include <vector>
21 
22 #include <grpcpp/grpcpp.h>
23 
24 #include "src/cpp/ext/proto_server_reflection.h"
25 
26 using grpc::Status;
27 using grpc::StatusCode;
28 using grpc::reflection::v1alpha::ErrorResponse;
29 using grpc::reflection::v1alpha::ExtensionNumberResponse;
30 using grpc::reflection::v1alpha::ExtensionRequest;
31 using grpc::reflection::v1alpha::FileDescriptorResponse;
32 using grpc::reflection::v1alpha::ListServiceResponse;
33 using grpc::reflection::v1alpha::ServerReflectionRequest;
34 using grpc::reflection::v1alpha::ServerReflectionResponse;
35 using grpc::reflection::v1alpha::ServiceResponse;
36 
37 namespace grpc {
38 
ProtoServerReflection()39 ProtoServerReflection::ProtoServerReflection()
40     : descriptor_pool_(protobuf::DescriptorPool::generated_pool()) {}
41 
SetServiceList(const std::vector<grpc::string> * services)42 void ProtoServerReflection::SetServiceList(
43     const std::vector<grpc::string>* services) {
44   services_ = services;
45 }
46 
ServerReflectionInfo(ServerContext * context,ServerReaderWriter<ServerReflectionResponse,ServerReflectionRequest> * stream)47 Status ProtoServerReflection::ServerReflectionInfo(
48     ServerContext* context,
49     ServerReaderWriter<ServerReflectionResponse, ServerReflectionRequest>*
50         stream) {
51   ServerReflectionRequest request;
52   ServerReflectionResponse response;
53   Status status;
54   while (stream->Read(&request)) {
55     switch (request.message_request_case()) {
56       case ServerReflectionRequest::MessageRequestCase::kFileByFilename:
57         status = GetFileByName(context, request.file_by_filename(), &response);
58         break;
59       case ServerReflectionRequest::MessageRequestCase::kFileContainingSymbol:
60         status = GetFileContainingSymbol(
61             context, request.file_containing_symbol(), &response);
62         break;
63       case ServerReflectionRequest::MessageRequestCase::
64           kFileContainingExtension:
65         status = GetFileContainingExtension(
66             context, &request.file_containing_extension(), &response);
67         break;
68       case ServerReflectionRequest::MessageRequestCase::
69           kAllExtensionNumbersOfType:
70         status = GetAllExtensionNumbers(
71             context, request.all_extension_numbers_of_type(),
72             response.mutable_all_extension_numbers_response());
73         break;
74       case ServerReflectionRequest::MessageRequestCase::kListServices:
75         status =
76             ListService(context, response.mutable_list_services_response());
77         break;
78       default:
79         status = Status(StatusCode::UNIMPLEMENTED, "");
80     }
81 
82     if (!status.ok()) {
83       FillErrorResponse(status, response.mutable_error_response());
84     }
85     response.set_valid_host(request.host());
86     response.set_allocated_original_request(
87         new ServerReflectionRequest(request));
88     stream->Write(response);
89   }
90 
91   return Status::OK;
92 }
93 
FillErrorResponse(const Status & status,ErrorResponse * error_response)94 void ProtoServerReflection::FillErrorResponse(const Status& status,
95                                               ErrorResponse* error_response) {
96   error_response->set_error_code(status.error_code());
97   error_response->set_error_message(status.error_message());
98 }
99 
ListService(ServerContext * context,ListServiceResponse * response)100 Status ProtoServerReflection::ListService(ServerContext* context,
101                                           ListServiceResponse* response) {
102   if (services_ == nullptr) {
103     return Status(StatusCode::NOT_FOUND, "Services not found.");
104   }
105   for (auto it = services_->begin(); it != services_->end(); ++it) {
106     ServiceResponse* service_response = response->add_service();
107     service_response->set_name(*it);
108   }
109   return Status::OK;
110 }
111 
GetFileByName(ServerContext * context,const grpc::string & filename,ServerReflectionResponse * response)112 Status ProtoServerReflection::GetFileByName(
113     ServerContext* context, const grpc::string& filename,
114     ServerReflectionResponse* response) {
115   if (descriptor_pool_ == nullptr) {
116     return Status::CANCELLED;
117   }
118 
119   const protobuf::FileDescriptor* file_desc =
120       descriptor_pool_->FindFileByName(filename);
121   if (file_desc == nullptr) {
122     return Status(StatusCode::NOT_FOUND, "File not found.");
123   }
124   std::unordered_set<grpc::string> seen_files;
125   FillFileDescriptorResponse(file_desc, response, &seen_files);
126   return Status::OK;
127 }
128 
GetFileContainingSymbol(ServerContext * context,const grpc::string & symbol,ServerReflectionResponse * response)129 Status ProtoServerReflection::GetFileContainingSymbol(
130     ServerContext* context, const grpc::string& symbol,
131     ServerReflectionResponse* response) {
132   if (descriptor_pool_ == nullptr) {
133     return Status::CANCELLED;
134   }
135 
136   const protobuf::FileDescriptor* file_desc =
137       descriptor_pool_->FindFileContainingSymbol(symbol);
138   if (file_desc == nullptr) {
139     return Status(StatusCode::NOT_FOUND, "Symbol not found.");
140   }
141   std::unordered_set<grpc::string> seen_files;
142   FillFileDescriptorResponse(file_desc, response, &seen_files);
143   return Status::OK;
144 }
145 
GetFileContainingExtension(ServerContext * context,const ExtensionRequest * request,ServerReflectionResponse * response)146 Status ProtoServerReflection::GetFileContainingExtension(
147     ServerContext* context, const ExtensionRequest* request,
148     ServerReflectionResponse* response) {
149   if (descriptor_pool_ == nullptr) {
150     return Status::CANCELLED;
151   }
152 
153   const protobuf::Descriptor* desc =
154       descriptor_pool_->FindMessageTypeByName(request->containing_type());
155   if (desc == nullptr) {
156     return Status(StatusCode::NOT_FOUND, "Type not found.");
157   }
158 
159   const protobuf::FieldDescriptor* field_desc =
160       descriptor_pool_->FindExtensionByNumber(desc,
161                                               request->extension_number());
162   if (field_desc == nullptr) {
163     return Status(StatusCode::NOT_FOUND, "Extension not found.");
164   }
165   std::unordered_set<grpc::string> seen_files;
166   FillFileDescriptorResponse(field_desc->file(), response, &seen_files);
167   return Status::OK;
168 }
169 
GetAllExtensionNumbers(ServerContext * context,const grpc::string & type,ExtensionNumberResponse * response)170 Status ProtoServerReflection::GetAllExtensionNumbers(
171     ServerContext* context, const grpc::string& type,
172     ExtensionNumberResponse* response) {
173   if (descriptor_pool_ == nullptr) {
174     return Status::CANCELLED;
175   }
176 
177   const protobuf::Descriptor* desc =
178       descriptor_pool_->FindMessageTypeByName(type);
179   if (desc == nullptr) {
180     return Status(StatusCode::NOT_FOUND, "Type not found.");
181   }
182 
183   std::vector<const protobuf::FieldDescriptor*> extensions;
184   descriptor_pool_->FindAllExtensions(desc, &extensions);
185   for (auto it = extensions.begin(); it != extensions.end(); it++) {
186     response->add_extension_number((*it)->number());
187   }
188   response->set_base_type_name(type);
189   return Status::OK;
190 }
191 
FillFileDescriptorResponse(const protobuf::FileDescriptor * file_desc,ServerReflectionResponse * response,std::unordered_set<grpc::string> * seen_files)192 void ProtoServerReflection::FillFileDescriptorResponse(
193     const protobuf::FileDescriptor* file_desc,
194     ServerReflectionResponse* response,
195     std::unordered_set<grpc::string>* seen_files) {
196   if (seen_files->find(file_desc->name()) != seen_files->end()) {
197     return;
198   }
199   seen_files->insert(file_desc->name());
200 
201   protobuf::FileDescriptorProto file_desc_proto;
202   grpc::string data;
203   file_desc->CopyTo(&file_desc_proto);
204   file_desc_proto.SerializeToString(&data);
205   response->mutable_file_descriptor_response()->add_file_descriptor_proto(data);
206 
207   for (int i = 0; i < file_desc->dependency_count(); ++i) {
208     FillFileDescriptorResponse(file_desc->dependency(i), response, seen_files);
209   }
210 }
211 
212 }  // namespace grpc
213