1 /* 2 * Copyright 2016 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.protobuf.services; 18 19 import static com.google.common.base.Preconditions.checkNotNull; 20 import static com.google.common.base.Preconditions.checkState; 21 22 import com.google.protobuf.Descriptors.Descriptor; 23 import com.google.protobuf.Descriptors.FieldDescriptor; 24 import com.google.protobuf.Descriptors.FileDescriptor; 25 import com.google.protobuf.Descriptors.MethodDescriptor; 26 import com.google.protobuf.Descriptors.ServiceDescriptor; 27 import io.grpc.BindableService; 28 import io.grpc.ExperimentalApi; 29 import io.grpc.InternalNotifyOnServerBuild; 30 import io.grpc.Server; 31 import io.grpc.ServerServiceDefinition; 32 import io.grpc.Status; 33 import io.grpc.protobuf.ProtoFileDescriptorSupplier; 34 import io.grpc.reflection.v1alpha.ErrorResponse; 35 import io.grpc.reflection.v1alpha.ExtensionNumberResponse; 36 import io.grpc.reflection.v1alpha.ExtensionRequest; 37 import io.grpc.reflection.v1alpha.FileDescriptorResponse; 38 import io.grpc.reflection.v1alpha.ListServiceResponse; 39 import io.grpc.reflection.v1alpha.ServerReflectionGrpc; 40 import io.grpc.reflection.v1alpha.ServerReflectionRequest; 41 import io.grpc.reflection.v1alpha.ServerReflectionResponse; 42 import io.grpc.reflection.v1alpha.ServiceResponse; 43 import io.grpc.stub.ServerCallStreamObserver; 44 import io.grpc.stub.StreamObserver; 45 import java.util.ArrayDeque; 46 import java.util.Collections; 47 import java.util.HashMap; 48 import java.util.HashSet; 49 import java.util.List; 50 import java.util.Map; 51 import java.util.Queue; 52 import java.util.Set; 53 import javax.annotation.Nullable; 54 import javax.annotation.concurrent.GuardedBy; 55 56 /** 57 * Provides a reflection service for Protobuf services (including the reflection service itself). 58 * 59 * <p>Separately tracks mutable and immutable services. Throws an exception if either group of 60 * services contains multiple Protobuf files with declarations of the same service, method, type, or 61 * extension. 62 */ 63 @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2222") 64 public final class ProtoReflectionService extends ServerReflectionGrpc.ServerReflectionImplBase 65 implements InternalNotifyOnServerBuild { 66 67 private final Object lock = new Object(); 68 69 @GuardedBy("lock") 70 private ServerReflectionIndex serverReflectionIndex; 71 72 private Server server; 73 ProtoReflectionService()74 private ProtoReflectionService() {} 75 newInstance()76 public static BindableService newInstance() { 77 return new ProtoReflectionService(); 78 } 79 80 /** Receives a reference to the server at build time. */ 81 @Override notifyOnBuild(Server server)82 public void notifyOnBuild(Server server) { 83 this.server = checkNotNull(server); 84 } 85 86 /** 87 * Checks for updates to the server's mutable services and updates the index if any changes are 88 * detected. A change is any addition or removal in the set of file descriptors attached to the 89 * mutable services or a change in the service names. 90 * 91 * @return The (potentially updated) index. 92 */ updateIndexIfNecessary()93 private ServerReflectionIndex updateIndexIfNecessary() { 94 synchronized (lock) { 95 if (serverReflectionIndex == null) { 96 serverReflectionIndex = 97 new ServerReflectionIndex(server.getImmutableServices(), server.getMutableServices()); 98 return serverReflectionIndex; 99 } 100 101 Set<FileDescriptor> serverFileDescriptors = new HashSet<FileDescriptor>(); 102 Set<String> serverServiceNames = new HashSet<String>(); 103 List<ServerServiceDefinition> serverMutableServices = server.getMutableServices(); 104 for (ServerServiceDefinition mutableService : serverMutableServices) { 105 io.grpc.ServiceDescriptor serviceDescriptor = mutableService.getServiceDescriptor(); 106 if (serviceDescriptor.getSchemaDescriptor() instanceof ProtoFileDescriptorSupplier) { 107 String serviceName = serviceDescriptor.getName(); 108 FileDescriptor fileDescriptor = 109 ((ProtoFileDescriptorSupplier) serviceDescriptor.getSchemaDescriptor()) 110 .getFileDescriptor(); 111 serverFileDescriptors.add(fileDescriptor); 112 serverServiceNames.add(serviceName); 113 } 114 } 115 116 // Replace the index if the underlying mutable services have changed. Check both the file 117 // descriptors and the service names, because one file descriptor can define multiple 118 // services. 119 FileDescriptorIndex mutableServicesIndex = serverReflectionIndex.getMutableServicesIndex(); 120 if (!mutableServicesIndex.getServiceFileDescriptors().equals(serverFileDescriptors) 121 || !mutableServicesIndex.getServiceNames().equals(serverServiceNames)) { 122 serverReflectionIndex = 123 new ServerReflectionIndex(server.getImmutableServices(), serverMutableServices); 124 } 125 126 return serverReflectionIndex; 127 } 128 } 129 130 @Override serverReflectionInfo( final StreamObserver<ServerReflectionResponse> responseObserver)131 public StreamObserver<ServerReflectionRequest> serverReflectionInfo( 132 final StreamObserver<ServerReflectionResponse> responseObserver) { 133 final ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver = 134 (ServerCallStreamObserver<ServerReflectionResponse>) responseObserver; 135 ProtoReflectionStreamObserver requestObserver = 136 new ProtoReflectionStreamObserver(updateIndexIfNecessary(), serverCallStreamObserver); 137 serverCallStreamObserver.setOnReadyHandler(requestObserver); 138 serverCallStreamObserver.disableAutoInboundFlowControl(); 139 serverCallStreamObserver.request(1); 140 return requestObserver; 141 } 142 143 private static class ProtoReflectionStreamObserver 144 implements Runnable, StreamObserver<ServerReflectionRequest> { 145 private final ServerReflectionIndex serverReflectionIndex; 146 private final ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver; 147 148 private boolean closeAfterSend = false; 149 private ServerReflectionRequest request; 150 ProtoReflectionStreamObserver( ServerReflectionIndex serverReflectionIndex, ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver)151 ProtoReflectionStreamObserver( 152 ServerReflectionIndex serverReflectionIndex, 153 ServerCallStreamObserver<ServerReflectionResponse> serverCallStreamObserver) { 154 this.serverReflectionIndex = serverReflectionIndex; 155 this.serverCallStreamObserver = checkNotNull(serverCallStreamObserver, "observer"); 156 } 157 158 @Override run()159 public void run() { 160 if (request != null) { 161 handleReflectionRequest(); 162 } 163 } 164 165 @Override onNext(ServerReflectionRequest request)166 public void onNext(ServerReflectionRequest request) { 167 checkState(this.request == null); 168 this.request = checkNotNull(request); 169 handleReflectionRequest(); 170 } 171 handleReflectionRequest()172 private void handleReflectionRequest() { 173 if (serverCallStreamObserver.isReady()) { 174 switch (request.getMessageRequestCase()) { 175 case FILE_BY_FILENAME: 176 getFileByName(request); 177 break; 178 case FILE_CONTAINING_SYMBOL: 179 getFileContainingSymbol(request); 180 break; 181 case FILE_CONTAINING_EXTENSION: 182 getFileByExtension(request); 183 break; 184 case ALL_EXTENSION_NUMBERS_OF_TYPE: 185 getAllExtensions(request); 186 break; 187 case LIST_SERVICES: 188 listServices(request); 189 break; 190 default: 191 sendErrorResponse( 192 request, 193 Status.Code.UNIMPLEMENTED, 194 "not implemented " + request.getMessageRequestCase()); 195 } 196 request = null; 197 if (closeAfterSend) { 198 serverCallStreamObserver.onCompleted(); 199 } else { 200 serverCallStreamObserver.request(1); 201 } 202 } 203 } 204 205 @Override onCompleted()206 public void onCompleted() { 207 if (request != null) { 208 closeAfterSend = true; 209 } else { 210 serverCallStreamObserver.onCompleted(); 211 } 212 } 213 214 @Override onError(Throwable cause)215 public void onError(Throwable cause) { 216 serverCallStreamObserver.onError(cause); 217 } 218 getFileByName(ServerReflectionRequest request)219 private void getFileByName(ServerReflectionRequest request) { 220 String name = request.getFileByFilename(); 221 FileDescriptor fd = serverReflectionIndex.getFileDescriptorByName(name); 222 if (fd != null) { 223 serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd)); 224 } else { 225 sendErrorResponse(request, Status.Code.NOT_FOUND, "File not found."); 226 } 227 } 228 getFileContainingSymbol(ServerReflectionRequest request)229 private void getFileContainingSymbol(ServerReflectionRequest request) { 230 String symbol = request.getFileContainingSymbol(); 231 FileDescriptor fd = serverReflectionIndex.getFileDescriptorBySymbol(symbol); 232 if (fd != null) { 233 serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd)); 234 } else { 235 sendErrorResponse(request, Status.Code.NOT_FOUND, "Symbol not found."); 236 } 237 } 238 getFileByExtension(ServerReflectionRequest request)239 private void getFileByExtension(ServerReflectionRequest request) { 240 ExtensionRequest extensionRequest = request.getFileContainingExtension(); 241 String type = extensionRequest.getContainingType(); 242 int extension = extensionRequest.getExtensionNumber(); 243 FileDescriptor fd = 244 serverReflectionIndex.getFileDescriptorByExtensionAndNumber(type, extension); 245 if (fd != null) { 246 serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd)); 247 } else { 248 sendErrorResponse(request, Status.Code.NOT_FOUND, "Extension not found."); 249 } 250 } 251 getAllExtensions(ServerReflectionRequest request)252 private void getAllExtensions(ServerReflectionRequest request) { 253 String type = request.getAllExtensionNumbersOfType(); 254 Set<Integer> extensions = serverReflectionIndex.getExtensionNumbersOfType(type); 255 if (extensions != null) { 256 ExtensionNumberResponse.Builder builder = 257 ExtensionNumberResponse.newBuilder() 258 .setBaseTypeName(type) 259 .addAllExtensionNumber(extensions); 260 serverCallStreamObserver.onNext( 261 ServerReflectionResponse.newBuilder() 262 .setValidHost(request.getHost()) 263 .setOriginalRequest(request) 264 .setAllExtensionNumbersResponse(builder) 265 .build()); 266 } else { 267 sendErrorResponse(request, Status.Code.NOT_FOUND, "Type not found."); 268 } 269 } 270 listServices(ServerReflectionRequest request)271 private void listServices(ServerReflectionRequest request) { 272 ListServiceResponse.Builder builder = ListServiceResponse.newBuilder(); 273 for (String serviceName : serverReflectionIndex.getServiceNames()) { 274 builder.addService(ServiceResponse.newBuilder().setName(serviceName)); 275 } 276 serverCallStreamObserver.onNext( 277 ServerReflectionResponse.newBuilder() 278 .setValidHost(request.getHost()) 279 .setOriginalRequest(request) 280 .setListServicesResponse(builder) 281 .build()); 282 } 283 sendErrorResponse( ServerReflectionRequest request, Status.Code code, String message)284 private void sendErrorResponse( 285 ServerReflectionRequest request, Status.Code code, String message) { 286 ServerReflectionResponse response = 287 ServerReflectionResponse.newBuilder() 288 .setValidHost(request.getHost()) 289 .setOriginalRequest(request) 290 .setErrorResponse( 291 ErrorResponse.newBuilder() 292 .setErrorCode(code.value()) 293 .setErrorMessage(message)) 294 .build(); 295 serverCallStreamObserver.onNext(response); 296 } 297 createServerReflectionResponse( ServerReflectionRequest request, FileDescriptor fd)298 private ServerReflectionResponse createServerReflectionResponse( 299 ServerReflectionRequest request, FileDescriptor fd) { 300 FileDescriptorResponse.Builder fdRBuilder = FileDescriptorResponse.newBuilder(); 301 302 Set<String> seenFiles = new HashSet<String>(); 303 Queue<FileDescriptor> frontier = new ArrayDeque<FileDescriptor>(); 304 seenFiles.add(fd.getName()); 305 frontier.add(fd); 306 while (!frontier.isEmpty()) { 307 FileDescriptor nextFd = frontier.remove(); 308 fdRBuilder.addFileDescriptorProto(nextFd.toProto().toByteString()); 309 for (FileDescriptor dependencyFd : nextFd.getDependencies()) { 310 if (!seenFiles.contains(dependencyFd.getName())) { 311 seenFiles.add(dependencyFd.getName()); 312 frontier.add(dependencyFd); 313 } 314 } 315 } 316 return ServerReflectionResponse.newBuilder() 317 .setValidHost(request.getHost()) 318 .setOriginalRequest(request) 319 .setFileDescriptorResponse(fdRBuilder) 320 .build(); 321 } 322 } 323 324 /** 325 * Indexes the server's services and allows lookups of file descriptors by filename, symbol, type, 326 * and extension number. 327 * 328 * <p>Internally, this stores separate indices for the immutable and mutable services. When 329 * queried, the immutable service index is checked for a matching value. Only if there is no match 330 * in the immutable service index are the mutable services checked. 331 */ 332 private static final class ServerReflectionIndex { 333 private final FileDescriptorIndex immutableServicesIndex; 334 private final FileDescriptorIndex mutableServicesIndex; 335 ServerReflectionIndex( List<ServerServiceDefinition> immutableServices, List<ServerServiceDefinition> mutableServices)336 public ServerReflectionIndex( 337 List<ServerServiceDefinition> immutableServices, 338 List<ServerServiceDefinition> mutableServices) { 339 immutableServicesIndex = new FileDescriptorIndex(immutableServices); 340 mutableServicesIndex = new FileDescriptorIndex(mutableServices); 341 } 342 getMutableServicesIndex()343 private FileDescriptorIndex getMutableServicesIndex() { 344 return mutableServicesIndex; 345 } 346 getServiceNames()347 private Set<String> getServiceNames() { 348 Set<String> immutableServiceNames = immutableServicesIndex.getServiceNames(); 349 Set<String> mutableServiceNames = mutableServicesIndex.getServiceNames(); 350 Set<String> serviceNames = 351 new HashSet<String>(immutableServiceNames.size() + mutableServiceNames.size()); 352 serviceNames.addAll(immutableServiceNames); 353 serviceNames.addAll(mutableServiceNames); 354 return serviceNames; 355 } 356 357 @Nullable getFileDescriptorByName(String name)358 private FileDescriptor getFileDescriptorByName(String name) { 359 FileDescriptor fd = immutableServicesIndex.getFileDescriptorByName(name); 360 if (fd == null) { 361 fd = mutableServicesIndex.getFileDescriptorByName(name); 362 } 363 return fd; 364 } 365 366 @Nullable getFileDescriptorBySymbol(String symbol)367 private FileDescriptor getFileDescriptorBySymbol(String symbol) { 368 FileDescriptor fd = immutableServicesIndex.getFileDescriptorBySymbol(symbol); 369 if (fd == null) { 370 fd = mutableServicesIndex.getFileDescriptorBySymbol(symbol); 371 } 372 return fd; 373 } 374 375 @Nullable getFileDescriptorByExtensionAndNumber(String type, int extension)376 private FileDescriptor getFileDescriptorByExtensionAndNumber(String type, int extension) { 377 FileDescriptor fd = 378 immutableServicesIndex.getFileDescriptorByExtensionAndNumber(type, extension); 379 if (fd == null) { 380 fd = mutableServicesIndex.getFileDescriptorByExtensionAndNumber(type, extension); 381 } 382 return fd; 383 } 384 385 @Nullable getExtensionNumbersOfType(String type)386 private Set<Integer> getExtensionNumbersOfType(String type) { 387 Set<Integer> extensionNumbers = immutableServicesIndex.getExtensionNumbersOfType(type); 388 if (extensionNumbers == null) { 389 extensionNumbers = mutableServicesIndex.getExtensionNumbersOfType(type); 390 } 391 return extensionNumbers; 392 } 393 } 394 395 /** 396 * Provides a set of methods for answering reflection queries for the file descriptors underlying 397 * a set of services. Used by {@link ServerReflectionIndex} to separately index immutable and 398 * mutable services. 399 */ 400 private static final class FileDescriptorIndex { 401 private final Set<String> serviceNames = new HashSet<String>(); 402 private final Set<FileDescriptor> serviceFileDescriptors = new HashSet<FileDescriptor>(); 403 private final Map<String, FileDescriptor> fileDescriptorsByName = 404 new HashMap<String, FileDescriptor>(); 405 private final Map<String, FileDescriptor> fileDescriptorsBySymbol = 406 new HashMap<String, FileDescriptor>(); 407 private final Map<String, Map<Integer, FileDescriptor>> fileDescriptorsByExtensionAndNumber = 408 new HashMap<String, Map<Integer, FileDescriptor>>(); 409 FileDescriptorIndex(List<ServerServiceDefinition> services)410 FileDescriptorIndex(List<ServerServiceDefinition> services) { 411 Queue<FileDescriptor> fileDescriptorsToProcess = new ArrayDeque<FileDescriptor>(); 412 Set<String> seenFiles = new HashSet<String>(); 413 for (ServerServiceDefinition service : services) { 414 io.grpc.ServiceDescriptor serviceDescriptor = service.getServiceDescriptor(); 415 if (serviceDescriptor.getSchemaDescriptor() instanceof ProtoFileDescriptorSupplier) { 416 FileDescriptor fileDescriptor = 417 ((ProtoFileDescriptorSupplier) serviceDescriptor.getSchemaDescriptor()) 418 .getFileDescriptor(); 419 String serviceName = serviceDescriptor.getName(); 420 checkState( 421 !serviceNames.contains(serviceName), "Service already defined: %s", serviceName); 422 serviceFileDescriptors.add(fileDescriptor); 423 serviceNames.add(serviceName); 424 if (!seenFiles.contains(fileDescriptor.getName())) { 425 seenFiles.add(fileDescriptor.getName()); 426 fileDescriptorsToProcess.add(fileDescriptor); 427 } 428 } 429 } 430 431 while (!fileDescriptorsToProcess.isEmpty()) { 432 FileDescriptor currentFd = fileDescriptorsToProcess.remove(); 433 processFileDescriptor(currentFd); 434 for (FileDescriptor dependencyFd : currentFd.getDependencies()) { 435 if (!seenFiles.contains(dependencyFd.getName())) { 436 seenFiles.add(dependencyFd.getName()); 437 fileDescriptorsToProcess.add(dependencyFd); 438 } 439 } 440 } 441 } 442 443 /** 444 * Returns the file descriptors for the indexed services, but not their dependencies. This is 445 * used to check if the server's mutable services have changed. 446 */ getServiceFileDescriptors()447 private Set<FileDescriptor> getServiceFileDescriptors() { 448 return Collections.unmodifiableSet(serviceFileDescriptors); 449 } 450 getServiceNames()451 private Set<String> getServiceNames() { 452 return Collections.unmodifiableSet(serviceNames); 453 } 454 455 @Nullable getFileDescriptorByName(String name)456 private FileDescriptor getFileDescriptorByName(String name) { 457 return fileDescriptorsByName.get(name); 458 } 459 460 @Nullable getFileDescriptorBySymbol(String symbol)461 private FileDescriptor getFileDescriptorBySymbol(String symbol) { 462 return fileDescriptorsBySymbol.get(symbol); 463 } 464 465 @Nullable getFileDescriptorByExtensionAndNumber(String type, int number)466 private FileDescriptor getFileDescriptorByExtensionAndNumber(String type, int number) { 467 if (fileDescriptorsByExtensionAndNumber.containsKey(type)) { 468 return fileDescriptorsByExtensionAndNumber.get(type).get(number); 469 } 470 return null; 471 } 472 473 @Nullable getExtensionNumbersOfType(String type)474 private Set<Integer> getExtensionNumbersOfType(String type) { 475 if (fileDescriptorsByExtensionAndNumber.containsKey(type)) { 476 return Collections.unmodifiableSet(fileDescriptorsByExtensionAndNumber.get(type).keySet()); 477 } 478 return null; 479 } 480 processFileDescriptor(FileDescriptor fd)481 private void processFileDescriptor(FileDescriptor fd) { 482 String fdName = fd.getName(); 483 checkState(!fileDescriptorsByName.containsKey(fdName), "File name already used: %s", fdName); 484 fileDescriptorsByName.put(fdName, fd); 485 for (ServiceDescriptor service : fd.getServices()) { 486 processService(service, fd); 487 } 488 for (Descriptor type : fd.getMessageTypes()) { 489 processType(type, fd); 490 } 491 for (FieldDescriptor extension : fd.getExtensions()) { 492 processExtension(extension, fd); 493 } 494 } 495 processService(ServiceDescriptor service, FileDescriptor fd)496 private void processService(ServiceDescriptor service, FileDescriptor fd) { 497 String serviceName = service.getFullName(); 498 checkState( 499 !fileDescriptorsBySymbol.containsKey(serviceName), 500 "Service already defined: %s", 501 serviceName); 502 fileDescriptorsBySymbol.put(serviceName, fd); 503 for (MethodDescriptor method : service.getMethods()) { 504 String methodName = method.getFullName(); 505 checkState( 506 !fileDescriptorsBySymbol.containsKey(methodName), 507 "Method already defined: %s", 508 methodName); 509 fileDescriptorsBySymbol.put(methodName, fd); 510 } 511 } 512 processType(Descriptor type, FileDescriptor fd)513 private void processType(Descriptor type, FileDescriptor fd) { 514 String typeName = type.getFullName(); 515 checkState( 516 !fileDescriptorsBySymbol.containsKey(typeName), "Type already defined: %s", typeName); 517 fileDescriptorsBySymbol.put(typeName, fd); 518 for (FieldDescriptor extension : type.getExtensions()) { 519 processExtension(extension, fd); 520 } 521 for (Descriptor nestedType : type.getNestedTypes()) { 522 processType(nestedType, fd); 523 } 524 } 525 processExtension(FieldDescriptor extension, FileDescriptor fd)526 private void processExtension(FieldDescriptor extension, FileDescriptor fd) { 527 String extensionName = extension.getContainingType().getFullName(); 528 int extensionNumber = extension.getNumber(); 529 if (!fileDescriptorsByExtensionAndNumber.containsKey(extensionName)) { 530 fileDescriptorsByExtensionAndNumber.put( 531 extensionName, new HashMap<Integer, FileDescriptor>()); 532 } 533 checkState( 534 !fileDescriptorsByExtensionAndNumber.get(extensionName).containsKey(extensionNumber), 535 "Extension name and number already defined: %s, %s", 536 extensionName, 537 extensionNumber); 538 fileDescriptorsByExtensionAndNumber.get(extensionName).put(extensionNumber, fd); 539 } 540 } 541 } 542