1 /*
2  * Copyright 2014 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.okhttp;
18 
19 import static com.google.common.base.Preconditions.checkNotNull;
20 import static com.google.common.base.Preconditions.checkState;
21 import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED;
22 
23 import com.google.common.io.BaseEncoding;
24 import io.grpc.Attributes;
25 import io.grpc.Metadata;
26 import io.grpc.MethodDescriptor;
27 import io.grpc.Status;
28 import io.grpc.internal.AbstractClientStream;
29 import io.grpc.internal.Http2ClientStreamTransportState;
30 import io.grpc.internal.StatsTraceContext;
31 import io.grpc.internal.TransportTracer;
32 import io.grpc.internal.WritableBuffer;
33 import io.grpc.okhttp.internal.framed.ErrorCode;
34 import io.grpc.okhttp.internal.framed.Header;
35 import java.util.ArrayDeque;
36 import java.util.List;
37 import java.util.Queue;
38 import javax.annotation.concurrent.GuardedBy;
39 import okio.Buffer;
40 
41 /**
42  * Client stream for the okhttp transport.
43  */
44 class OkHttpClientStream extends AbstractClientStream {
45 
46   private static final int WINDOW_UPDATE_THRESHOLD = Utils.DEFAULT_WINDOW_SIZE / 2;
47 
48   private static final Buffer EMPTY_BUFFER = new Buffer();
49 
50   public static final int ABSENT_ID = -1;
51 
52   private final MethodDescriptor<?, ?> method;
53 
54   private final String userAgent;
55   private final StatsTraceContext statsTraceCtx;
56   private String authority;
57   private Object outboundFlowState;
58   private volatile int id = ABSENT_ID;
59   private final TransportState state;
60   private final Sink sink = new Sink();
61   private final Attributes attributes;
62 
63   private boolean useGet = false;
64 
OkHttpClientStream( MethodDescriptor<?, ?> method, Metadata headers, AsyncFrameWriter frameWriter, OkHttpClientTransport transport, OutboundFlowController outboundFlow, Object lock, int maxMessageSize, String authority, String userAgent, StatsTraceContext statsTraceCtx, TransportTracer transportTracer)65   OkHttpClientStream(
66       MethodDescriptor<?, ?> method,
67       Metadata headers,
68       AsyncFrameWriter frameWriter,
69       OkHttpClientTransport transport,
70       OutboundFlowController outboundFlow,
71       Object lock,
72       int maxMessageSize,
73       String authority,
74       String userAgent,
75       StatsTraceContext statsTraceCtx,
76       TransportTracer transportTracer) {
77     super(
78         new OkHttpWritableBufferAllocator(),
79         statsTraceCtx,
80         transportTracer,
81         headers,
82         method.isSafe());
83     this.statsTraceCtx = checkNotNull(statsTraceCtx, "statsTraceCtx");
84     this.method = method;
85     this.authority = authority;
86     this.userAgent = userAgent;
87     // OkHttpClientStream is only created after the transport has finished connecting,
88     // so it is safe to read the transport attributes.
89     // We make a copy here for convenience, even though we can ask the transport.
90     this.attributes = transport.getAttributes();
91     this.state = new TransportState(maxMessageSize, statsTraceCtx, lock, frameWriter, outboundFlow,
92         transport);
93   }
94 
95   @Override
transportState()96   protected TransportState transportState() {
97     return state;
98   }
99 
100   @Override
abstractClientStreamSink()101   protected Sink abstractClientStreamSink() {
102     return sink;
103   }
104 
105   /**
106    * Returns the type of this stream.
107    */
getType()108   public MethodDescriptor.MethodType getType() {
109     return method.getType();
110   }
111 
id()112   public int id() {
113     return id;
114   }
115 
116   /**
117    * Returns whether the stream uses GET. This is not known until after {@link Sink#writeHeaders} is
118    * invoked.
119    */
useGet()120   boolean useGet() {
121     return useGet;
122   }
123 
124   @Override
setAuthority(String authority)125   public void setAuthority(String authority) {
126     this.authority = checkNotNull(authority, "authority");
127   }
128 
129   @Override
getAttributes()130   public Attributes getAttributes() {
131     return attributes;
132   }
133 
134   class Sink implements AbstractClientStream.Sink {
135     @SuppressWarnings("BetaApi") // BaseEncoding is stable in Guava 20.0
136     @Override
writeHeaders(Metadata metadata, byte[] payload)137     public void writeHeaders(Metadata metadata, byte[] payload) {
138       String defaultPath = "/" + method.getFullMethodName();
139       if (payload != null) {
140         useGet = true;
141         defaultPath += "?" + BaseEncoding.base64().encode(payload);
142       }
143       synchronized (state.lock) {
144         state.streamReady(metadata, defaultPath);
145       }
146     }
147 
148     @Override
writeFrame( WritableBuffer frame, boolean endOfStream, boolean flush, int numMessages)149     public void writeFrame(
150         WritableBuffer frame, boolean endOfStream, boolean flush, int numMessages) {
151       Buffer buffer;
152       if (frame == null) {
153         buffer = EMPTY_BUFFER;
154       } else {
155         buffer = ((OkHttpWritableBuffer) frame).buffer();
156         int size = (int) buffer.size();
157         if (size > 0) {
158           onSendingBytes(size);
159         }
160       }
161 
162       synchronized (state.lock) {
163         state.sendBuffer(buffer, endOfStream, flush);
164         getTransportTracer().reportMessageSent(numMessages);
165       }
166     }
167 
168     @Override
request(final int numMessages)169     public void request(final int numMessages) {
170       synchronized (state.lock) {
171         state.requestMessagesFromDeframer(numMessages);
172       }
173     }
174 
175     @Override
cancel(Status reason)176     public void cancel(Status reason) {
177       synchronized (state.lock) {
178         state.cancel(reason, true, null);
179       }
180     }
181   }
182 
183   class TransportState extends Http2ClientStreamTransportState {
184     private final Object lock;
185     @GuardedBy("lock")
186     private List<Header> requestHeaders;
187     /**
188      * Null iff {@link #requestHeaders} is null.  Non-null iff neither {@link #cancel} nor
189      * {@link #start(int)} have been called.
190      */
191     @GuardedBy("lock")
192     private Queue<PendingData> pendingData = new ArrayDeque<PendingData>();
193     @GuardedBy("lock")
194     private boolean cancelSent = false;
195     @GuardedBy("lock")
196     private int window = Utils.DEFAULT_WINDOW_SIZE;
197     @GuardedBy("lock")
198     private int processedWindow = Utils.DEFAULT_WINDOW_SIZE;
199     @GuardedBy("lock")
200     private final AsyncFrameWriter frameWriter;
201     @GuardedBy("lock")
202     private final OutboundFlowController outboundFlow;
203     @GuardedBy("lock")
204     private final OkHttpClientTransport transport;
205 
TransportState( int maxMessageSize, StatsTraceContext statsTraceCtx, Object lock, AsyncFrameWriter frameWriter, OutboundFlowController outboundFlow, OkHttpClientTransport transport)206     public TransportState(
207         int maxMessageSize,
208         StatsTraceContext statsTraceCtx,
209         Object lock,
210         AsyncFrameWriter frameWriter,
211         OutboundFlowController outboundFlow,
212         OkHttpClientTransport transport) {
213       super(maxMessageSize, statsTraceCtx, OkHttpClientStream.this.getTransportTracer());
214       this.lock = checkNotNull(lock, "lock");
215       this.frameWriter = frameWriter;
216       this.outboundFlow = outboundFlow;
217       this.transport = transport;
218     }
219 
220     @GuardedBy("lock")
start(int streamId)221     public void start(int streamId) {
222       checkState(id == ABSENT_ID, "the stream has been started with id %s", streamId);
223       id = streamId;
224       state.onStreamAllocated();
225 
226       if (pendingData != null) {
227         // Only happens when the stream has neither been started nor cancelled.
228         frameWriter.synStream(useGet, false, id, 0, requestHeaders);
229         statsTraceCtx.clientOutboundHeaders();
230         requestHeaders = null;
231 
232         boolean flush = false;
233         while (!pendingData.isEmpty()) {
234           PendingData data = pendingData.poll();
235           outboundFlow.data(data.endOfStream, id, data.buffer, false);
236           if (data.flush) {
237             flush = true;
238           }
239         }
240         if (flush) {
241           outboundFlow.flush();
242         }
243         pendingData = null;
244       }
245     }
246 
247     @GuardedBy("lock")
248     @Override
onStreamAllocated()249     protected void onStreamAllocated() {
250       super.onStreamAllocated();
251       getTransportTracer().reportLocalStreamStarted();
252     }
253 
254     @GuardedBy("lock")
255     @Override
http2ProcessingFailed(Status status, boolean stopDelivery, Metadata trailers)256     protected void http2ProcessingFailed(Status status, boolean stopDelivery, Metadata trailers) {
257       cancel(status, stopDelivery, trailers);
258     }
259 
260     @Override
261     @GuardedBy("lock")
deframeFailed(Throwable cause)262     public void deframeFailed(Throwable cause) {
263       http2ProcessingFailed(Status.fromThrowable(cause), true, new Metadata());
264     }
265 
266     @Override
267     @GuardedBy("lock")
bytesRead(int processedBytes)268     public void bytesRead(int processedBytes) {
269       processedWindow -= processedBytes;
270       if (processedWindow <= WINDOW_UPDATE_THRESHOLD) {
271         int delta = Utils.DEFAULT_WINDOW_SIZE - processedWindow;
272         window += delta;
273         processedWindow += delta;
274         frameWriter.windowUpdate(id(), delta);
275       }
276     }
277 
278     @Override
279     @GuardedBy("lock")
deframerClosed(boolean hasPartialMessage)280     public void deframerClosed(boolean hasPartialMessage) {
281       onEndOfStream();
282       super.deframerClosed(hasPartialMessage);
283     }
284 
285     @Override
286     @GuardedBy("lock")
runOnTransportThread(final Runnable r)287     public void runOnTransportThread(final Runnable r) {
288       synchronized (lock) {
289         r.run();
290       }
291     }
292 
293     /**
294      * Must be called with holding the transport lock.
295      */
296     @GuardedBy("lock")
transportHeadersReceived(List<Header> headers, boolean endOfStream)297     public void transportHeadersReceived(List<Header> headers, boolean endOfStream) {
298       if (endOfStream) {
299         transportTrailersReceived(Utils.convertTrailers(headers));
300       } else {
301         transportHeadersReceived(Utils.convertHeaders(headers));
302       }
303     }
304 
305     /**
306      * Must be called with holding the transport lock.
307      */
308     @GuardedBy("lock")
transportDataReceived(okio.Buffer frame, boolean endOfStream)309     public void transportDataReceived(okio.Buffer frame, boolean endOfStream) {
310       // We only support 16 KiB frames, and the max permitted in HTTP/2 is 16 MiB. This is verified
311       // in OkHttp's Http2 deframer. In addition, this code is after the data has been read.
312       int length = (int) frame.size();
313       window -= length;
314       if (window < 0) {
315         frameWriter.rstStream(id(), ErrorCode.FLOW_CONTROL_ERROR);
316         transport.finishStream(
317             id(),
318             Status.INTERNAL.withDescription(
319                 "Received data size exceeded our receiving window size"),
320             PROCESSED, false, null, null);
321         return;
322       }
323       super.transportDataReceived(new OkHttpReadableBuffer(frame), endOfStream);
324     }
325 
326     @GuardedBy("lock")
onEndOfStream()327     private void onEndOfStream() {
328       if (!isOutboundClosed()) {
329         // If server's end-of-stream is received before client sends end-of-stream, we just send a
330         // reset to server to fully close the server side stream.
331         transport.finishStream(id(),null, PROCESSED, false, ErrorCode.CANCEL, null);
332       } else {
333         transport.finishStream(id(), null, PROCESSED, false, null, null);
334       }
335     }
336 
337     @GuardedBy("lock")
cancel(Status reason, boolean stopDelivery, Metadata trailers)338     private void cancel(Status reason, boolean stopDelivery, Metadata trailers) {
339       if (cancelSent) {
340         return;
341       }
342       cancelSent = true;
343       if (pendingData != null) {
344         // stream is pending.
345         transport.removePendingStream(OkHttpClientStream.this);
346         // release holding data, so they can be GCed or returned to pool earlier.
347         requestHeaders = null;
348         for (PendingData data : pendingData) {
349           data.buffer.clear();
350         }
351         pendingData = null;
352         transportReportStatus(reason, true, trailers != null ? trailers : new Metadata());
353       } else {
354         // If pendingData is null, start must have already been called, which means synStream has
355         // been called as well.
356         transport.finishStream(
357             id(), reason, PROCESSED, stopDelivery, ErrorCode.CANCEL, trailers);
358       }
359     }
360 
361     @GuardedBy("lock")
sendBuffer(Buffer buffer, boolean endOfStream, boolean flush)362     private void sendBuffer(Buffer buffer, boolean endOfStream, boolean flush) {
363       if (cancelSent) {
364         return;
365       }
366       if (pendingData != null) {
367         // Stream is pending start, queue the data.
368         pendingData.add(new PendingData(buffer, endOfStream, flush));
369       } else {
370         checkState(id() != ABSENT_ID, "streamId should be set");
371         // If buffer > frameWriter.maxDataLength() the flow-controller will ensure that it is
372         // properly chunked.
373         outboundFlow.data(endOfStream, id(), buffer, flush);
374       }
375     }
376 
377     @GuardedBy("lock")
streamReady(Metadata metadata, String path)378     private void streamReady(Metadata metadata, String path) {
379       requestHeaders = Headers.createRequestHeaders(metadata, path, authority, userAgent, useGet);
380       transport.streamReadyToStart(OkHttpClientStream.this);
381     }
382   }
383 
setOutboundFlowState(Object outboundFlowState)384   void setOutboundFlowState(Object outboundFlowState) {
385     this.outboundFlowState = outboundFlowState;
386   }
387 
getOutboundFlowState()388   Object getOutboundFlowState() {
389     return outboundFlowState;
390   }
391 
392   private static class PendingData {
393     Buffer buffer;
394     boolean endOfStream;
395     boolean flush;
396 
PendingData(Buffer buffer, boolean endOfStream, boolean flush)397     PendingData(Buffer buffer, boolean endOfStream, boolean flush) {
398       this.buffer = buffer;
399       this.endOfStream = endOfStream;
400       this.flush = flush;
401     }
402   }
403 }
404