• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.protobuf.services;
18 
19 import static com.google.common.base.Preconditions.checkNotNull;
20 import static com.google.common.base.Preconditions.checkState;
21 
22 import com.google.protobuf.Descriptors.Descriptor;
23 import com.google.protobuf.Descriptors.FieldDescriptor;
24 import com.google.protobuf.Descriptors.FileDescriptor;
25 import com.google.protobuf.Descriptors.MethodDescriptor;
26 import com.google.protobuf.Descriptors.ServiceDescriptor;
27 import io.grpc.BindableService;
28 import io.grpc.ExperimentalApi;
29 import io.grpc.InternalNotifyOnServerBuild;
30 import io.grpc.Server;
31 import io.grpc.ServerServiceDefinition;
32 import io.grpc.Status;
33 import io.grpc.protobuf.ProtoFileDescriptorSupplier;
34 import io.grpc.reflection.v1alpha.ErrorResponse;
35 import io.grpc.reflection.v1alpha.ExtensionNumberResponse;
36 import io.grpc.reflection.v1alpha.ExtensionRequest;
37 import io.grpc.reflection.v1alpha.FileDescriptorResponse;
38 import io.grpc.reflection.v1alpha.ListServiceResponse;
39 import io.grpc.reflection.v1alpha.ServerReflectionGrpc;
40 import io.grpc.reflection.v1alpha.ServerReflectionRequest;
41 import io.grpc.reflection.v1alpha.ServerReflectionResponse;
42 import io.grpc.reflection.v1alpha.ServiceResponse;
43 import io.grpc.stub.ServerCallStreamObserver;
44 import io.grpc.stub.StreamObserver;
45 import java.util.ArrayDeque;
46 import java.util.Collections;
47 import java.util.HashMap;
48 import java.util.HashSet;
49 import java.util.List;
50 import java.util.Map;
51 import java.util.Queue;
52 import java.util.Set;
53 import javax.annotation.Nullable;
54 import javax.annotation.concurrent.GuardedBy;
55 
56 /**
57  * Provides a reflection service for Protobuf services (including the reflection service itself).
58  *
59  * <p>Separately tracks mutable and immutable services. Throws an exception if either group of
60  * services contains multiple Protobuf files with declarations of the same service, method, type, or
61  * extension.
62  */
63 @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2222")
64 public final class ProtoReflectionService extends ServerReflectionGrpc.ServerReflectionImplBase
65     implements InternalNotifyOnServerBuild {
66 
67   private final Object lock = new Object();
68 
69   @GuardedBy("lock")
70   private ServerReflectionIndex serverReflectionIndex;
71 
72   private Server server;
73 
ProtoReflectionService()74   private ProtoReflectionService() {}
75 
newInstance()76   public static BindableService newInstance() {
77     return new ProtoReflectionService();
78   }
79 
80   /** Receives a reference to the server at build time. */
81   @Override
notifyOnBuild(Server server)82   public void notifyOnBuild(Server server) {
83     this.server = checkNotNull(server);
84   }
85 
86   /**
87    * Checks for updates to the server's mutable services and updates the index if any changes are
88    * detected. A change is any addition or removal in the set of file descriptors attached to the
89    * mutable services or a change in the service names.
90    *
91    * @return The (potentially updated) index.
92    */
updateIndexIfNecessary()93   private ServerReflectionIndex updateIndexIfNecessary() {
94     synchronized (lock) {
95       if (serverReflectionIndex == null) {
96         serverReflectionIndex =
97             new ServerReflectionIndex(server.getImmutableServices(), server.getMutableServices());
98         return serverReflectionIndex;
99       }
100 
101       Set<FileDescriptor> serverFileDescriptors = new HashSet<FileDescriptor>();
102       Set<String> serverServiceNames = new HashSet<String>();
103       List<ServerServiceDefinition> serverMutableServices = server.getMutableServices();
104       for (ServerServiceDefinition mutableService : serverMutableServices) {
105         io.grpc.ServiceDescriptor serviceDescriptor = mutableService.getServiceDescriptor();
106         if (serviceDescriptor.getSchemaDescriptor() instanceof ProtoFileDescriptorSupplier) {
107           String serviceName = serviceDescriptor.getName();
108           FileDescriptor fileDescriptor =
109               ((ProtoFileDescriptorSupplier) serviceDescriptor.getSchemaDescriptor())
110                   .getFileDescriptor();
111           serverFileDescriptors.add(fileDescriptor);
112           serverServiceNames.add(serviceName);
113         }
114       }
115 
116       // Replace the index if the underlying mutable services have changed. Check both the file
117       // descriptors and the service names, because one file descriptor can define multiple
118       // services.
119       FileDescriptorIndex mutableServicesIndex = serverReflectionIndex.getMutableServicesIndex();
120       if (!mutableServicesIndex.getServiceFileDescriptors().equals(serverFileDescriptors)
121           || !mutableServicesIndex.getServiceNames().equals(serverServiceNames)) {
122         serverReflectionIndex =
123             new ServerReflectionIndex(server.getImmutableServices(), serverMutableServices);
124       }
125 
126       return serverReflectionIndex;
127     }
128   }
129 
130   @Override
serverReflectionInfo( final StreamObserver<ServerReflectionResponse> responseObserver)131   public StreamObserver<ServerReflectionRequest> serverReflectionInfo(
132       final StreamObserver<ServerReflectionResponse> responseObserver) {
133     final ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver =
134         (ServerCallStreamObserver<ServerReflectionResponse>) responseObserver;
135     ProtoReflectionStreamObserver requestObserver =
136         new ProtoReflectionStreamObserver(updateIndexIfNecessary(), serverCallStreamObserver);
137     serverCallStreamObserver.setOnReadyHandler(requestObserver);
138     serverCallStreamObserver.disableAutoInboundFlowControl();
139     serverCallStreamObserver.request(1);
140     return requestObserver;
141   }
142 
143   private static class ProtoReflectionStreamObserver
144       implements Runnable, StreamObserver<ServerReflectionRequest> {
145     private final ServerReflectionIndex serverReflectionIndex;
146     private final ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver;
147 
148     private boolean closeAfterSend = false;
149     private ServerReflectionRequest request;
150 
ProtoReflectionStreamObserver( ServerReflectionIndex serverReflectionIndex, ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver)151     ProtoReflectionStreamObserver(
152         ServerReflectionIndex serverReflectionIndex,
153         ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver) {
154       this.serverReflectionIndex = serverReflectionIndex;
155       this.serverCallStreamObserver = checkNotNull(serverCallStreamObserver, "observer");
156     }
157 
158     @Override
run()159     public void run() {
160       if (request != null) {
161         handleReflectionRequest();
162       }
163     }
164 
165     @Override
onNext(ServerReflectionRequest request)166     public void onNext(ServerReflectionRequest request) {
167       checkState(this.request == null);
168       this.request = checkNotNull(request);
169       handleReflectionRequest();
170     }
171 
handleReflectionRequest()172     private void handleReflectionRequest() {
173       if (serverCallStreamObserver.isReady()) {
174         switch (request.getMessageRequestCase()) {
175           case FILE_BY_FILENAME:
176             getFileByName(request);
177             break;
178           case FILE_CONTAINING_SYMBOL:
179             getFileContainingSymbol(request);
180             break;
181           case FILE_CONTAINING_EXTENSION:
182             getFileByExtension(request);
183             break;
184           case ALL_EXTENSION_NUMBERS_OF_TYPE:
185             getAllExtensions(request);
186             break;
187           case LIST_SERVICES:
188             listServices(request);
189             break;
190           default:
191             sendErrorResponse(
192                 request,
193                 Status.Code.UNIMPLEMENTED,
194                 "not implemented " + request.getMessageRequestCase());
195         }
196         request = null;
197         if (closeAfterSend) {
198           serverCallStreamObserver.onCompleted();
199         } else {
200           serverCallStreamObserver.request(1);
201         }
202       }
203     }
204 
205     @Override
onCompleted()206     public void onCompleted() {
207       if (request != null) {
208         closeAfterSend = true;
209       } else {
210         serverCallStreamObserver.onCompleted();
211       }
212     }
213 
214     @Override
onError(Throwable cause)215     public void onError(Throwable cause) {
216       serverCallStreamObserver.onError(cause);
217     }
218 
getFileByName(ServerReflectionRequest request)219     private void getFileByName(ServerReflectionRequest request) {
220       String name = request.getFileByFilename();
221       FileDescriptor fd = serverReflectionIndex.getFileDescriptorByName(name);
222       if (fd != null) {
223         serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd));
224       } else {
225         sendErrorResponse(request, Status.Code.NOT_FOUND, "File not found.");
226       }
227     }
228 
getFileContainingSymbol(ServerReflectionRequest request)229     private void getFileContainingSymbol(ServerReflectionRequest request) {
230       String symbol = request.getFileContainingSymbol();
231       FileDescriptor fd = serverReflectionIndex.getFileDescriptorBySymbol(symbol);
232       if (fd != null) {
233         serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd));
234       } else {
235         sendErrorResponse(request, Status.Code.NOT_FOUND, "Symbol not found.");
236       }
237     }
238 
getFileByExtension(ServerReflectionRequest request)239     private void getFileByExtension(ServerReflectionRequest request) {
240       ExtensionRequest extensionRequest = request.getFileContainingExtension();
241       String type = extensionRequest.getContainingType();
242       int extension = extensionRequest.getExtensionNumber();
243       FileDescriptor fd =
244           serverReflectionIndex.getFileDescriptorByExtensionAndNumber(type, extension);
245       if (fd != null) {
246         serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd));
247       } else {
248         sendErrorResponse(request, Status.Code.NOT_FOUND, "Extension not found.");
249       }
250     }
251 
getAllExtensions(ServerReflectionRequest request)252     private void getAllExtensions(ServerReflectionRequest request) {
253       String type = request.getAllExtensionNumbersOfType();
254       Set<Integer> extensions = serverReflectionIndex.getExtensionNumbersOfType(type);
255       if (extensions != null) {
256         ExtensionNumberResponse.Builder builder =
257             ExtensionNumberResponse.newBuilder()
258                 .setBaseTypeName(type)
259                 .addAllExtensionNumber(extensions);
260         serverCallStreamObserver.onNext(
261             ServerReflectionResponse.newBuilder()
262                 .setValidHost(request.getHost())
263                 .setOriginalRequest(request)
264                 .setAllExtensionNumbersResponse(builder)
265                 .build());
266       } else {
267         sendErrorResponse(request, Status.Code.NOT_FOUND, "Type not found.");
268       }
269     }
270 
listServices(ServerReflectionRequest request)271     private void listServices(ServerReflectionRequest request) {
272       ListServiceResponse.Builder builder = ListServiceResponse.newBuilder();
273       for (String serviceName : serverReflectionIndex.getServiceNames()) {
274         builder.addService(ServiceResponse.newBuilder().setName(serviceName));
275       }
276       serverCallStreamObserver.onNext(
277           ServerReflectionResponse.newBuilder()
278               .setValidHost(request.getHost())
279               .setOriginalRequest(request)
280               .setListServicesResponse(builder)
281               .build());
282     }
283 
sendErrorResponse( ServerReflectionRequest request, Status.Code code, String message)284     private void sendErrorResponse(
285         ServerReflectionRequest request, Status.Code code, String message) {
286       ServerReflectionResponse response =
287           ServerReflectionResponse.newBuilder()
288               .setValidHost(request.getHost())
289               .setOriginalRequest(request)
290               .setErrorResponse(
291                   ErrorResponse.newBuilder()
292                       .setErrorCode(code.value())
293                       .setErrorMessage(message))
294               .build();
295       serverCallStreamObserver.onNext(response);
296     }
297 
createServerReflectionResponse( ServerReflectionRequest request, FileDescriptor fd)298     private ServerReflectionResponse createServerReflectionResponse(
299         ServerReflectionRequest request, FileDescriptor fd) {
300       FileDescriptorResponse.Builder fdRBuilder = FileDescriptorResponse.newBuilder();
301 
302       Set<String> seenFiles = new HashSet<String>();
303       Queue<FileDescriptor> frontier = new ArrayDeque<FileDescriptor>();
304       seenFiles.add(fd.getName());
305       frontier.add(fd);
306       while (!frontier.isEmpty()) {
307         FileDescriptor nextFd = frontier.remove();
308         fdRBuilder.addFileDescriptorProto(nextFd.toProto().toByteString());
309         for (FileDescriptor dependencyFd : nextFd.getDependencies()) {
310           if (!seenFiles.contains(dependencyFd.getName())) {
311             seenFiles.add(dependencyFd.getName());
312             frontier.add(dependencyFd);
313           }
314         }
315       }
316       return ServerReflectionResponse.newBuilder()
317           .setValidHost(request.getHost())
318           .setOriginalRequest(request)
319           .setFileDescriptorResponse(fdRBuilder)
320           .build();
321     }
322   }
323 
324   /**
325    * Indexes the server's services and allows lookups of file descriptors by filename, symbol, type,
326    * and extension number.
327    *
328    * <p>Internally, this stores separate indices for the immutable and mutable services. When
329    * queried, the immutable service index is checked for a matching value. Only if there is no match
330    * in the immutable service index are the mutable services checked.
331    */
332   private static final class ServerReflectionIndex {
333     private final FileDescriptorIndex immutableServicesIndex;
334     private final FileDescriptorIndex mutableServicesIndex;
335 
ServerReflectionIndex( List<ServerServiceDefinition> immutableServices, List<ServerServiceDefinition> mutableServices)336     public ServerReflectionIndex(
337         List<ServerServiceDefinition> immutableServices,
338         List<ServerServiceDefinition> mutableServices) {
339       immutableServicesIndex = new FileDescriptorIndex(immutableServices);
340       mutableServicesIndex = new FileDescriptorIndex(mutableServices);
341     }
342 
getMutableServicesIndex()343     private FileDescriptorIndex getMutableServicesIndex() {
344       return mutableServicesIndex;
345     }
346 
getServiceNames()347     private Set<String> getServiceNames() {
348       Set<String> immutableServiceNames = immutableServicesIndex.getServiceNames();
349       Set<String> mutableServiceNames = mutableServicesIndex.getServiceNames();
350       Set<String> serviceNames =
351           new HashSet<String>(immutableServiceNames.size() + mutableServiceNames.size());
352       serviceNames.addAll(immutableServiceNames);
353       serviceNames.addAll(mutableServiceNames);
354       return serviceNames;
355     }
356 
357     @Nullable
getFileDescriptorByName(String name)358     private FileDescriptor getFileDescriptorByName(String name) {
359       FileDescriptor fd = immutableServicesIndex.getFileDescriptorByName(name);
360       if (fd == null) {
361         fd = mutableServicesIndex.getFileDescriptorByName(name);
362       }
363       return fd;
364     }
365 
366     @Nullable
getFileDescriptorBySymbol(String symbol)367     private FileDescriptor getFileDescriptorBySymbol(String symbol) {
368       FileDescriptor fd = immutableServicesIndex.getFileDescriptorBySymbol(symbol);
369       if (fd == null) {
370         fd = mutableServicesIndex.getFileDescriptorBySymbol(symbol);
371       }
372       return fd;
373     }
374 
375     @Nullable
getFileDescriptorByExtensionAndNumber(String type, int extension)376     private FileDescriptor getFileDescriptorByExtensionAndNumber(String type, int extension) {
377       FileDescriptor fd =
378           immutableServicesIndex.getFileDescriptorByExtensionAndNumber(type, extension);
379       if (fd == null) {
380         fd = mutableServicesIndex.getFileDescriptorByExtensionAndNumber(type, extension);
381       }
382       return fd;
383     }
384 
385     @Nullable
getExtensionNumbersOfType(String type)386     private Set<Integer> getExtensionNumbersOfType(String type) {
387       Set<Integer> extensionNumbers = immutableServicesIndex.getExtensionNumbersOfType(type);
388       if (extensionNumbers == null) {
389         extensionNumbers = mutableServicesIndex.getExtensionNumbersOfType(type);
390       }
391       return extensionNumbers;
392     }
393   }
394 
395   /**
396    * Provides a set of methods for answering reflection queries for the file descriptors underlying
397    * a set of services. Used by {@link ServerReflectionIndex} to separately index immutable and
398    * mutable services.
399    */
400   private static final class FileDescriptorIndex {
401     private final Set<String> serviceNames = new HashSet<String>();
402     private final Set<FileDescriptor> serviceFileDescriptors = new HashSet<FileDescriptor>();
403     private final Map<String, FileDescriptor> fileDescriptorsByName =
404         new HashMap<String, FileDescriptor>();
405     private final Map<String, FileDescriptor> fileDescriptorsBySymbol =
406         new HashMap<String, FileDescriptor>();
407     private final Map<String, Map<Integer, FileDescriptor>> fileDescriptorsByExtensionAndNumber =
408         new HashMap<String, Map<Integer, FileDescriptor>>();
409 
FileDescriptorIndex(List<ServerServiceDefinition> services)410     FileDescriptorIndex(List<ServerServiceDefinition> services) {
411       Queue<FileDescriptor> fileDescriptorsToProcess = new ArrayDeque<FileDescriptor>();
412       Set<String> seenFiles = new HashSet<String>();
413       for (ServerServiceDefinition service : services) {
414         io.grpc.ServiceDescriptor serviceDescriptor = service.getServiceDescriptor();
415         if (serviceDescriptor.getSchemaDescriptor() instanceof ProtoFileDescriptorSupplier) {
416           FileDescriptor fileDescriptor =
417               ((ProtoFileDescriptorSupplier) serviceDescriptor.getSchemaDescriptor())
418                   .getFileDescriptor();
419           String serviceName = serviceDescriptor.getName();
420           checkState(
421               !serviceNames.contains(serviceName), "Service already defined: %s", serviceName);
422           serviceFileDescriptors.add(fileDescriptor);
423           serviceNames.add(serviceName);
424           if (!seenFiles.contains(fileDescriptor.getName())) {
425             seenFiles.add(fileDescriptor.getName());
426             fileDescriptorsToProcess.add(fileDescriptor);
427           }
428         }
429       }
430 
431       while (!fileDescriptorsToProcess.isEmpty()) {
432         FileDescriptor currentFd = fileDescriptorsToProcess.remove();
433         processFileDescriptor(currentFd);
434         for (FileDescriptor dependencyFd : currentFd.getDependencies()) {
435           if (!seenFiles.contains(dependencyFd.getName())) {
436             seenFiles.add(dependencyFd.getName());
437             fileDescriptorsToProcess.add(dependencyFd);
438           }
439         }
440       }
441     }
442 
443     /**
444      * Returns the file descriptors for the indexed services, but not their dependencies. This is
445      * used to check if the server's mutable services have changed.
446      */
getServiceFileDescriptors()447     private Set<FileDescriptor> getServiceFileDescriptors() {
448       return Collections.unmodifiableSet(serviceFileDescriptors);
449     }
450 
getServiceNames()451     private Set<String> getServiceNames() {
452       return Collections.unmodifiableSet(serviceNames);
453     }
454 
455     @Nullable
getFileDescriptorByName(String name)456     private FileDescriptor getFileDescriptorByName(String name) {
457       return fileDescriptorsByName.get(name);
458     }
459 
460     @Nullable
getFileDescriptorBySymbol(String symbol)461     private FileDescriptor getFileDescriptorBySymbol(String symbol) {
462       return fileDescriptorsBySymbol.get(symbol);
463     }
464 
465     @Nullable
getFileDescriptorByExtensionAndNumber(String type, int number)466     private FileDescriptor getFileDescriptorByExtensionAndNumber(String type, int number) {
467       if (fileDescriptorsByExtensionAndNumber.containsKey(type)) {
468         return fileDescriptorsByExtensionAndNumber.get(type).get(number);
469       }
470       return null;
471     }
472 
473     @Nullable
getExtensionNumbersOfType(String type)474     private Set<Integer> getExtensionNumbersOfType(String type) {
475       if (fileDescriptorsByExtensionAndNumber.containsKey(type)) {
476         return Collections.unmodifiableSet(fileDescriptorsByExtensionAndNumber.get(type).keySet());
477       }
478       return null;
479     }
480 
processFileDescriptor(FileDescriptor fd)481     private void processFileDescriptor(FileDescriptor fd) {
482       String fdName = fd.getName();
483       checkState(!fileDescriptorsByName.containsKey(fdName), "File name already used: %s", fdName);
484       fileDescriptorsByName.put(fdName, fd);
485       for (ServiceDescriptor service : fd.getServices()) {
486         processService(service, fd);
487       }
488       for (Descriptor type : fd.getMessageTypes()) {
489         processType(type, fd);
490       }
491       for (FieldDescriptor extension : fd.getExtensions()) {
492         processExtension(extension, fd);
493       }
494     }
495 
processService(ServiceDescriptor service, FileDescriptor fd)496     private void processService(ServiceDescriptor service, FileDescriptor fd) {
497       String serviceName = service.getFullName();
498       checkState(
499           !fileDescriptorsBySymbol.containsKey(serviceName),
500           "Service already defined: %s",
501           serviceName);
502       fileDescriptorsBySymbol.put(serviceName, fd);
503       for (MethodDescriptor method : service.getMethods()) {
504         String methodName = method.getFullName();
505         checkState(
506             !fileDescriptorsBySymbol.containsKey(methodName),
507             "Method already defined: %s",
508             methodName);
509         fileDescriptorsBySymbol.put(methodName, fd);
510       }
511     }
512 
processType(Descriptor type, FileDescriptor fd)513     private void processType(Descriptor type, FileDescriptor fd) {
514       String typeName = type.getFullName();
515       checkState(
516           !fileDescriptorsBySymbol.containsKey(typeName), "Type already defined: %s", typeName);
517       fileDescriptorsBySymbol.put(typeName, fd);
518       for (FieldDescriptor extension : type.getExtensions()) {
519         processExtension(extension, fd);
520       }
521       for (Descriptor nestedType : type.getNestedTypes()) {
522         processType(nestedType, fd);
523       }
524     }
525 
processExtension(FieldDescriptor extension, FileDescriptor fd)526     private void processExtension(FieldDescriptor extension, FileDescriptor fd) {
527       String extensionName = extension.getContainingType().getFullName();
528       int extensionNumber = extension.getNumber();
529       if (!fileDescriptorsByExtensionAndNumber.containsKey(extensionName)) {
530         fileDescriptorsByExtensionAndNumber.put(
531             extensionName, new HashMap<Integer, FileDescriptor>());
532       }
533       checkState(
534           !fileDescriptorsByExtensionAndNumber.get(extensionName).containsKey(extensionNumber),
535           "Extension name and number already defined: %s, %s",
536           extensionName,
537           extensionNumber);
538       fileDescriptorsByExtensionAndNumber.get(extensionName).put(extensionNumber, fd);
539     }
540   }
541 }
542