1 /*
2  *
3  * Copyright 2016 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
#include "test/cpp/util/proto_reflection_descriptor_database.h"

#include <mutex>
#include <vector>

#include <grpc/support/log.h>
24 
25 using grpc::reflection::v1alpha::ErrorResponse;
26 using grpc::reflection::v1alpha::ListServiceResponse;
27 using grpc::reflection::v1alpha::ServerReflection;
28 using grpc::reflection::v1alpha::ServerReflectionRequest;
29 using grpc::reflection::v1alpha::ServerReflectionResponse;
30 
31 namespace grpc {
32 
ProtoReflectionDescriptorDatabase(std::unique_ptr<ServerReflection::Stub> stub)33 ProtoReflectionDescriptorDatabase::ProtoReflectionDescriptorDatabase(
34     std::unique_ptr<ServerReflection::Stub> stub)
35     : stub_(std::move(stub)) {}
36 
ProtoReflectionDescriptorDatabase(const std::shared_ptr<grpc::Channel> & channel)37 ProtoReflectionDescriptorDatabase::ProtoReflectionDescriptorDatabase(
38     const std::shared_ptr<grpc::Channel>& channel)
39     : stub_(ServerReflection::NewStub(channel)) {}
40 
~ProtoReflectionDescriptorDatabase()41 ProtoReflectionDescriptorDatabase::~ProtoReflectionDescriptorDatabase() {
42   if (stream_) {
43     stream_->WritesDone();
44     Status status = stream_->Finish();
45     if (!status.ok()) {
46       if (status.error_code() == StatusCode::UNIMPLEMENTED) {
47         gpr_log(GPR_INFO,
48                 "Reflection request not implemented; "
49                 "is the ServerReflection service enabled?");
50       }
51       gpr_log(GPR_INFO,
52               "ServerReflectionInfo rpc failed. Error code: %d, details: %s",
53               static_cast<int>(status.error_code()),
54               status.error_message().c_str());
55     }
56   }
57 }
58 
FindFileByName(const string & filename,protobuf::FileDescriptorProto * output)59 bool ProtoReflectionDescriptorDatabase::FindFileByName(
60     const string& filename, protobuf::FileDescriptorProto* output) {
61   if (cached_db_.FindFileByName(filename, output)) {
62     return true;
63   }
64 
65   if (known_files_.find(filename) != known_files_.end()) {
66     return false;
67   }
68 
69   ServerReflectionRequest request;
70   request.set_file_by_filename(filename);
71   ServerReflectionResponse response;
72 
73   if (!DoOneRequest(request, response)) {
74     return false;
75   }
76 
77   if (response.message_response_case() ==
78       ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) {
79     AddFileFromResponse(response.file_descriptor_response());
80   } else if (response.message_response_case() ==
81              ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
82     const ErrorResponse& error = response.error_response();
83     if (error.error_code() == StatusCode::NOT_FOUND) {
84       gpr_log(GPR_INFO, "NOT_FOUND from server for FindFileByName(%s)",
85               filename.c_str());
86     } else {
87       gpr_log(GPR_INFO,
88               "Error on FindFileByName(%s)\n\tError code: %d\n"
89               "\tError Message: %s",
90               filename.c_str(), error.error_code(),
91               error.error_message().c_str());
92     }
93   } else {
94     gpr_log(
95         GPR_INFO,
96         "Error on FindFileByName(%s) response type\n"
97         "\tExpecting: %d\n\tReceived: %d",
98         filename.c_str(),
99         ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse,
100         response.message_response_case());
101   }
102 
103   return cached_db_.FindFileByName(filename, output);
104 }
105 
FindFileContainingSymbol(const string & symbol_name,protobuf::FileDescriptorProto * output)106 bool ProtoReflectionDescriptorDatabase::FindFileContainingSymbol(
107     const string& symbol_name, protobuf::FileDescriptorProto* output) {
108   if (cached_db_.FindFileContainingSymbol(symbol_name, output)) {
109     return true;
110   }
111 
112   if (missing_symbols_.find(symbol_name) != missing_symbols_.end()) {
113     return false;
114   }
115 
116   ServerReflectionRequest request;
117   request.set_file_containing_symbol(symbol_name);
118   ServerReflectionResponse response;
119 
120   if (!DoOneRequest(request, response)) {
121     return false;
122   }
123 
124   if (response.message_response_case() ==
125       ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) {
126     AddFileFromResponse(response.file_descriptor_response());
127   } else if (response.message_response_case() ==
128              ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
129     const ErrorResponse& error = response.error_response();
130     if (error.error_code() == StatusCode::NOT_FOUND) {
131       missing_symbols_.insert(symbol_name);
132       gpr_log(GPR_INFO,
133               "NOT_FOUND from server for FindFileContainingSymbol(%s)",
134               symbol_name.c_str());
135     } else {
136       gpr_log(GPR_INFO,
137               "Error on FindFileContainingSymbol(%s)\n"
138               "\tError code: %d\n\tError Message: %s",
139               symbol_name.c_str(), error.error_code(),
140               error.error_message().c_str());
141     }
142   } else {
143     gpr_log(
144         GPR_INFO,
145         "Error on FindFileContainingSymbol(%s) response type\n"
146         "\tExpecting: %d\n\tReceived: %d",
147         symbol_name.c_str(),
148         ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse,
149         response.message_response_case());
150   }
151   return cached_db_.FindFileContainingSymbol(symbol_name, output);
152 }
153 
FindFileContainingExtension(const string & containing_type,int field_number,protobuf::FileDescriptorProto * output)154 bool ProtoReflectionDescriptorDatabase::FindFileContainingExtension(
155     const string& containing_type, int field_number,
156     protobuf::FileDescriptorProto* output) {
157   if (cached_db_.FindFileContainingExtension(containing_type, field_number,
158                                              output)) {
159     return true;
160   }
161 
162   if (missing_extensions_.find(containing_type) != missing_extensions_.end() &&
163       missing_extensions_[containing_type].find(field_number) !=
164           missing_extensions_[containing_type].end()) {
165     gpr_log(GPR_INFO, "nested map.");
166     return false;
167   }
168 
169   ServerReflectionRequest request;
170   request.mutable_file_containing_extension()->set_containing_type(
171       containing_type);
172   request.mutable_file_containing_extension()->set_extension_number(
173       field_number);
174   ServerReflectionResponse response;
175 
176   if (!DoOneRequest(request, response)) {
177     return false;
178   }
179 
180   if (response.message_response_case() ==
181       ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) {
182     AddFileFromResponse(response.file_descriptor_response());
183   } else if (response.message_response_case() ==
184              ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
185     const ErrorResponse& error = response.error_response();
186     if (error.error_code() == StatusCode::NOT_FOUND) {
187       if (missing_extensions_.find(containing_type) ==
188           missing_extensions_.end()) {
189         missing_extensions_[containing_type] = {};
190       }
191       missing_extensions_[containing_type].insert(field_number);
192       gpr_log(GPR_INFO,
193               "NOT_FOUND from server for FindFileContainingExtension(%s, %d)",
194               containing_type.c_str(), field_number);
195     } else {
196       gpr_log(GPR_INFO,
197               "Error on FindFileContainingExtension(%s, %d)\n"
198               "\tError code: %d\n\tError Message: %s",
199               containing_type.c_str(), field_number, error.error_code(),
200               error.error_message().c_str());
201     }
202   } else {
203     gpr_log(
204         GPR_INFO,
205         "Error on FindFileContainingExtension(%s, %d) response type\n"
206         "\tExpecting: %d\n\tReceived: %d",
207         containing_type.c_str(), field_number,
208         ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse,
209         response.message_response_case());
210   }
211 
212   return cached_db_.FindFileContainingExtension(containing_type, field_number,
213                                                 output);
214 }
215 
FindAllExtensionNumbers(const string & extendee_type,std::vector<int> * output)216 bool ProtoReflectionDescriptorDatabase::FindAllExtensionNumbers(
217     const string& extendee_type, std::vector<int>* output) {
218   if (cached_extension_numbers_.find(extendee_type) !=
219       cached_extension_numbers_.end()) {
220     *output = cached_extension_numbers_[extendee_type];
221     return true;
222   }
223 
224   ServerReflectionRequest request;
225   request.set_all_extension_numbers_of_type(extendee_type);
226   ServerReflectionResponse response;
227 
228   if (!DoOneRequest(request, response)) {
229     return false;
230   }
231 
232   if (response.message_response_case() ==
233       ServerReflectionResponse::MessageResponseCase::
234           kAllExtensionNumbersResponse) {
235     auto number = response.all_extension_numbers_response().extension_number();
236     *output = std::vector<int>(number.begin(), number.end());
237     cached_extension_numbers_[extendee_type] = *output;
238     return true;
239   } else if (response.message_response_case() ==
240              ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
241     const ErrorResponse& error = response.error_response();
242     if (error.error_code() == StatusCode::NOT_FOUND) {
243       gpr_log(GPR_INFO, "NOT_FOUND from server for FindAllExtensionNumbers(%s)",
244               extendee_type.c_str());
245     } else {
246       gpr_log(GPR_INFO,
247               "Error on FindAllExtensionNumbersExtension(%s)\n"
248               "\tError code: %d\n\tError Message: %s",
249               extendee_type.c_str(), error.error_code(),
250               error.error_message().c_str());
251     }
252   }
253   return false;
254 }
255 
GetServices(std::vector<grpc::string> * output)256 bool ProtoReflectionDescriptorDatabase::GetServices(
257     std::vector<grpc::string>* output) {
258   ServerReflectionRequest request;
259   request.set_list_services("");
260   ServerReflectionResponse response;
261 
262   if (!DoOneRequest(request, response)) {
263     return false;
264   }
265 
266   if (response.message_response_case() ==
267       ServerReflectionResponse::MessageResponseCase::kListServicesResponse) {
268     const ListServiceResponse& ls_response = response.list_services_response();
269     for (int i = 0; i < ls_response.service_size(); ++i) {
270       (*output).push_back(ls_response.service(i).name());
271     }
272     return true;
273   } else if (response.message_response_case() ==
274              ServerReflectionResponse::MessageResponseCase::kErrorResponse) {
275     const ErrorResponse& error = response.error_response();
276     gpr_log(GPR_INFO,
277             "Error on GetServices()\n\tError code: %d\n"
278             "\tError Message: %s",
279             error.error_code(), error.error_message().c_str());
280   } else {
281     gpr_log(
282         GPR_INFO,
283         "Error on GetServices() response type\n\tExpecting: %d\n\tReceived: %d",
284         ServerReflectionResponse::MessageResponseCase::kListServicesResponse,
285         response.message_response_case());
286   }
287   return false;
288 }
289 
290 const protobuf::FileDescriptorProto
ParseFileDescriptorProtoResponse(const grpc::string & byte_fd_proto)291 ProtoReflectionDescriptorDatabase::ParseFileDescriptorProtoResponse(
292     const grpc::string& byte_fd_proto) {
293   protobuf::FileDescriptorProto file_desc_proto;
294   file_desc_proto.ParseFromString(byte_fd_proto);
295   return file_desc_proto;
296 }
297 
AddFileFromResponse(const grpc::reflection::v1alpha::FileDescriptorResponse & response)298 void ProtoReflectionDescriptorDatabase::AddFileFromResponse(
299     const grpc::reflection::v1alpha::FileDescriptorResponse& response) {
300   for (int i = 0; i < response.file_descriptor_proto_size(); ++i) {
301     const protobuf::FileDescriptorProto file_proto =
302         ParseFileDescriptorProtoResponse(response.file_descriptor_proto(i));
303     if (known_files_.find(file_proto.name()) == known_files_.end()) {
304       known_files_.insert(file_proto.name());
305       cached_db_.Add(file_proto);
306     }
307   }
308 }
309 
310 const std::shared_ptr<ProtoReflectionDescriptorDatabase::ClientStream>
GetStream()311 ProtoReflectionDescriptorDatabase::GetStream() {
312   if (!stream_) {
313     stream_ = stub_->ServerReflectionInfo(&ctx_);
314   }
315   return stream_;
316 }
317 
DoOneRequest(const ServerReflectionRequest & request,ServerReflectionResponse & response)318 bool ProtoReflectionDescriptorDatabase::DoOneRequest(
319     const ServerReflectionRequest& request,
320     ServerReflectionResponse& response) {
321   bool success = false;
322   stream_mutex_.lock();
323   if (GetStream()->Write(request) && GetStream()->Read(&response)) {
324     success = true;
325   }
326   stream_mutex_.unlock();
327   return success;
328 }
329 
330 }  // namespace grpc
331