1 /*
2  * Copyright 2014 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.protobuf.lite;
18 
19 import static com.google.common.base.Preconditions.checkNotNull;
20 
21 import com.google.common.annotations.VisibleForTesting;
22 import com.google.protobuf.CodedInputStream;
23 import com.google.protobuf.ExtensionRegistryLite;
24 import com.google.protobuf.InvalidProtocolBufferException;
25 import com.google.protobuf.MessageLite;
26 import com.google.protobuf.Parser;
27 import io.grpc.ExperimentalApi;
28 import io.grpc.KnownLength;
29 import io.grpc.Metadata;
30 import io.grpc.MethodDescriptor.Marshaller;
31 import io.grpc.MethodDescriptor.PrototypeMarshaller;
32 import io.grpc.Status;
33 import java.io.IOException;
34 import java.io.InputStream;
35 import java.io.OutputStream;
36 import java.lang.ref.Reference;
37 import java.lang.ref.WeakReference;
38 
39 /**
40  * Utility methods for using protobuf with grpc.
41  */
42 @ExperimentalApi("Experimental until Lite is stable in protobuf")
43 public final class ProtoLiteUtils {
44 
45   // default visibility to avoid synthetic accessors
46   static volatile ExtensionRegistryLite globalRegistry =
47       ExtensionRegistryLite.getEmptyRegistry();
48 
49   private static final int BUF_SIZE = 8192;
50 
51   /**
52    * The same value as {@link io.grpc.internal.GrpcUtil#DEFAULT_MAX_MESSAGE_SIZE}.
53    */
54   @VisibleForTesting
55   static final int DEFAULT_MAX_MESSAGE_SIZE = 4 * 1024 * 1024;
56 
57   /**
58    * Sets the global registry for proto marshalling shared across all servers and clients.
59    *
60    * <p>Warning:  This API will likely change over time.  It is not possible to have separate
61    * registries per Process, Server, Channel, Service, or Method.  This is intentional until there
62    * is a more appropriate API to set them.
63    *
64    * <p>Warning:  Do NOT modify the extension registry after setting it.  It is thread safe to call
65    * {@link #setExtensionRegistry}, but not to modify the underlying object.
66    *
67    * <p>If you need custom parsing behavior for protos, you will need to make your own
68    * {@code MethodDescriptor.Marshaller} for the time being.
69    *
70    * @since 1.0.0
71    */
72   @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1787")
setExtensionRegistry(ExtensionRegistryLite newRegistry)73   public static void setExtensionRegistry(ExtensionRegistryLite newRegistry) {
74     globalRegistry = checkNotNull(newRegistry, "newRegistry");
75   }
76 
77   /**
78    * Creates a {@link Marshaller} for protos of the same type as {@code defaultInstance}.
79    *
80    * @since 1.0.0
81    */
marshaller(T defaultInstance)82   public static <T extends MessageLite> Marshaller<T> marshaller(T defaultInstance) {
83     // TODO(ejona): consider changing return type to PrototypeMarshaller (assuming ABI safe)
84     return new MessageMarshaller<T>(defaultInstance);
85   }
86 
87   /**
88    * Produce a metadata marshaller for a protobuf type.
89    *
90    * @since 1.0.0
91    */
metadataMarshaller( T defaultInstance)92   public static <T extends MessageLite> Metadata.BinaryMarshaller<T> metadataMarshaller(
93       T defaultInstance) {
94     return new MetadataMarshaller<T>(defaultInstance);
95   }
96 
97   /** Copies the data from input stream to output stream. */
copy(InputStream from, OutputStream to)98   static long copy(InputStream from, OutputStream to) throws IOException {
99     // Copied from guava com.google.common.io.ByteStreams because its API is unstable (beta)
100     checkNotNull(from);
101     checkNotNull(to);
102     byte[] buf = new byte[BUF_SIZE];
103     long total = 0;
104     while (true) {
105       int r = from.read(buf);
106       if (r == -1) {
107         break;
108       }
109       to.write(buf, 0, r);
110       total += r;
111     }
112     return total;
113   }
114 
ProtoLiteUtils()115   private ProtoLiteUtils() {
116   }
117 
118   private static final class MessageMarshaller<T extends MessageLite>
119       implements PrototypeMarshaller<T> {
120     private static final ThreadLocal<Reference<byte[]>> bufs = new ThreadLocal<Reference<byte[]>>();
121 
122     private final Parser<T> parser;
123     private final T defaultInstance;
124 
125     @SuppressWarnings("unchecked")
MessageMarshaller(T defaultInstance)126     MessageMarshaller(T defaultInstance) {
127       this.defaultInstance = defaultInstance;
128       parser = (Parser<T>) defaultInstance.getParserForType();
129     }
130 
131 
132     @SuppressWarnings("unchecked")
133     @Override
getMessageClass()134     public Class<T> getMessageClass() {
135       // Precisely T since protobuf doesn't let messages extend other messages.
136       return (Class<T>) defaultInstance.getClass();
137     }
138 
139     @Override
getMessagePrototype()140     public T getMessagePrototype() {
141       return defaultInstance;
142     }
143 
144     @Override
stream(T value)145     public InputStream stream(T value) {
146       return new ProtoInputStream(value, parser);
147     }
148 
149     @Override
parse(InputStream stream)150     public T parse(InputStream stream) {
151       if (stream instanceof ProtoInputStream) {
152         ProtoInputStream protoStream = (ProtoInputStream) stream;
153         // Optimization for in-memory transport. Returning provided object is safe since protobufs
154         // are immutable.
155         //
156         // However, we can't assume the types match, so we have to verify the parser matches.
157         // Today the parser is always the same for a given proto, but that isn't guaranteed. Even
158         // if not, using the same MethodDescriptor would ensure the parser matches and permit us
159         // to enable this optimization.
160         if (protoStream.parser() == parser) {
161           try {
162             @SuppressWarnings("unchecked")
163             T message = (T) ((ProtoInputStream) stream).message();
164             return message;
165           } catch (IllegalStateException ex) {
166             // Stream must have been read from, which is a strange state. Since the point of this
167             // optimization is to be transparent, instead of throwing an error we'll continue,
168             // even though it seems likely there's a bug.
169           }
170         }
171       }
172       CodedInputStream cis = null;
173       try {
174         if (stream instanceof KnownLength) {
175           int size = stream.available();
176           if (size > 0 && size <= DEFAULT_MAX_MESSAGE_SIZE) {
177             Reference<byte[]> ref;
178             // buf should not be used after this method has returned.
179             byte[] buf;
180             if ((ref = bufs.get()) == null || (buf = ref.get()) == null || buf.length < size) {
181               buf = new byte[size];
182               bufs.set(new WeakReference<byte[]>(buf));
183             }
184 
185             int remaining = size;
186             while (remaining > 0) {
187               int position = size - remaining;
188               int count = stream.read(buf, position, remaining);
189               if (count == -1) {
190                 break;
191               }
192               remaining -= count;
193             }
194 
195             if (remaining != 0) {
196               int position = size - remaining;
197               throw new RuntimeException("size inaccurate: " + size + " != " + position);
198             }
199             cis = CodedInputStream.newInstance(buf, 0, size);
200           } else if (size == 0) {
201             return defaultInstance;
202           }
203         }
204       } catch (IOException e) {
205         throw new RuntimeException(e);
206       }
207       if (cis == null) {
208         cis = CodedInputStream.newInstance(stream);
209       }
210       // Pre-create the CodedInputStream so that we can remove the size limit restriction
211       // when parsing.
212       cis.setSizeLimit(Integer.MAX_VALUE);
213 
214       try {
215         return parseFrom(cis);
216       } catch (InvalidProtocolBufferException ipbe) {
217         throw Status.INTERNAL.withDescription("Invalid protobuf byte sequence")
218             .withCause(ipbe).asRuntimeException();
219       }
220     }
221 
parseFrom(CodedInputStream stream)222     private T parseFrom(CodedInputStream stream) throws InvalidProtocolBufferException {
223       T message = parser.parseFrom(stream, globalRegistry);
224       try {
225         stream.checkLastTagWas(0);
226         return message;
227       } catch (InvalidProtocolBufferException e) {
228         e.setUnfinishedMessage(message);
229         throw e;
230       }
231     }
232   }
233 
234   private static final class MetadataMarshaller<T extends MessageLite>
235       implements Metadata.BinaryMarshaller<T> {
236 
237     private final T defaultInstance;
238 
MetadataMarshaller(T defaultInstance)239     MetadataMarshaller(T defaultInstance) {
240       this.defaultInstance = defaultInstance;
241     }
242 
243     @Override
toBytes(T value)244     public byte[] toBytes(T value) {
245       return value.toByteArray();
246     }
247 
248     @Override
249     @SuppressWarnings("unchecked")
parseBytes(byte[] serialized)250     public T parseBytes(byte[] serialized) {
251       try {
252         return (T) defaultInstance.getParserForType().parseFrom(serialized, globalRegistry);
253       } catch (InvalidProtocolBufferException ipbe) {
254         throw new IllegalArgumentException(ipbe);
255       }
256     }
257   }
258 }
259