1/*
2 *
3 * Copyright 2014 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package transport
20
21import (
22	"bufio"
23	"bytes"
24	"encoding/base64"
25	"fmt"
26	"net"
27	"net/http"
28	"strconv"
29	"strings"
30	"time"
31	"unicode/utf8"
32
33	"github.com/golang/protobuf/proto"
34	"golang.org/x/net/http2"
35	"golang.org/x/net/http2/hpack"
36	spb "google.golang.org/genproto/googleapis/rpc/status"
37	"google.golang.org/grpc/codes"
38	"google.golang.org/grpc/status"
39)
40
const (
	// http2MaxFrameLen specifies the max length of a HTTP2 frame.
	http2MaxFrameLen = 16384 // 16KB frame
	// http2InitHeaderTableSize is the HPACK table size passed to the header
	// decoder in newFramer. See
	// http://http2.github.io/http2-spec/#SettingValues
	http2InitHeaderTableSize = 4096
	// defaultWriteBufSize and defaultReadBufSize are the default sizes, in
	// bytes, of the buffers used for writing and reading frames.
	defaultWriteBufSize = 32 * 1024
	defaultReadBufSize  = 32 * 1024
	// baseContentType is the base content-type for gRPC. This is a valid
	// content-type on its own, but can also include a content-subtype such as
	// "proto" as a suffix after "+" or ";". See
	// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
	// for more details.
	baseContentType = "application/grpc"
)
56
var (
	// clientPreface is http2.ClientPreface as a byte slice.
	// http2ErrConvTab maps HTTP/2 error codes to the gRPC status code that
	// corresponds to each of them.
	clientPreface   = []byte(http2.ClientPreface)
	http2ErrConvTab = map[http2.ErrCode]codes.Code{
		http2.ErrCodeNo:                 codes.Internal,
		http2.ErrCodeProtocol:           codes.Internal,
		http2.ErrCodeInternal:           codes.Internal,
		http2.ErrCodeFlowControl:        codes.ResourceExhausted,
		http2.ErrCodeSettingsTimeout:    codes.Internal,
		http2.ErrCodeStreamClosed:       codes.Internal,
		http2.ErrCodeFrameSize:          codes.Internal,
		http2.ErrCodeRefusedStream:      codes.Unavailable,
		http2.ErrCodeCancel:             codes.Canceled,
		http2.ErrCodeCompression:        codes.Internal,
		http2.ErrCodeConnect:            codes.Internal,
		http2.ErrCodeEnhanceYourCalm:    codes.ResourceExhausted,
		http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
		http2.ErrCodeHTTP11Required:     codes.Internal,
	}
	// statusCodeConvTab maps gRPC status codes to HTTP/2 error codes. Only the
	// codes listed here have an entry; it is not a full inverse of
	// http2ErrConvTab (which maps several HTTP/2 codes to codes.Internal).
	statusCodeConvTab = map[codes.Code]http2.ErrCode{
		codes.Internal:          http2.ErrCodeInternal,
		codes.Canceled:          http2.ErrCodeCancel,
		codes.Unavailable:       http2.ErrCodeRefusedStream,
		codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm,
		codes.PermissionDenied:  http2.ErrCodeInadequateSecurity,
	}
	// httpStatusConvTab maps HTTP :status values to gRPC status codes.
	// decodeResponseHeader consults it for non-200 responses that carry no
	// gRPC status and falls back to codes.Unknown for statuses not listed.
	httpStatusConvTab = map[int]codes.Code{
		// 400 Bad Request - INTERNAL.
		http.StatusBadRequest: codes.Internal,
		// 401 Unauthorized  - UNAUTHENTICATED.
		http.StatusUnauthorized: codes.Unauthenticated,
		// 403 Forbidden - PERMISSION_DENIED.
		http.StatusForbidden: codes.PermissionDenied,
		// 404 Not Found - UNIMPLEMENTED.
		http.StatusNotFound: codes.Unimplemented,
		// 429 Too Many Requests - UNAVAILABLE.
		http.StatusTooManyRequests: codes.Unavailable,
		// 502 Bad Gateway - UNAVAILABLE.
		http.StatusBadGateway: codes.Unavailable,
		// 503 Service Unavailable - UNAVAILABLE.
		http.StatusServiceUnavailable: codes.Unavailable,
		// 504 Gateway timeout - UNAVAILABLE.
		http.StatusGatewayTimeout: codes.Unavailable,
	}
)
101
// decodeState records the state gathered while HPACK-decoding a header block
// (its fields are filled in by processHeaderField). It must be reset once the
// decoding of the entire header block is finished.
type decodeState struct {
	// encoding is the value of the grpc-encoding header field.
	encoding string
	// statusGen caches the stream status received from the trailer the server
	// sent. Client side only. Do not access directly. After all trailers are
	// parsed, use the status method to retrieve the status. It is set directly
	// from grpc-status-details-bin, or generated lazily by the status method.
	statusGen *status.Status
	// rawStatusCode and rawStatusMsg are set from the raw grpc-status and
	// grpc-message trailer fields (the message is stored percent-decoded) and
	// are not intended for direct access outside of parsing. decodeResponseHeader
	// may also set rawStatusCode to codes.Unknown. httpStatus holds the parsed
	// :status pseudo-header. The pointer fields stay nil until set.
	rawStatusCode *int
	rawStatusMsg  string
	httpStatus    *int
	// Server side only fields: timeoutSet and timeout come from grpc-timeout,
	// method from the :path pseudo-header.
	timeoutSet bool
	timeout    time.Duration
	method     string
	// mdata is the key-value metadata map from the peer (filled by
	// addMetadata). statsTags and statsTrace hold the base64-decoded
	// grpc-tags-bin and grpc-trace-bin values. contentSubtype is the
	// content-subtype parsed from content-type ("" if none was given).
	mdata          map[string][]string
	statsTags      []byte
	statsTrace     []byte
	contentSubtype string
}
125
// isReservedHeader reports whether hdr is an HTTP/2 header reserved by the
// gRPC protocol: any pseudo-header (a name beginning with ':') or one of the
// fixed transport-level names below. Every other header is classified as
// user-specified metadata.
func isReservedHeader(hdr string) bool {
	if strings.HasPrefix(hdr, ":") {
		return true
	}
	switch hdr {
	case "content-type", "user-agent", "grpc-message-type", "grpc-encoding",
		"grpc-message", "grpc-status", "grpc-timeout",
		"grpc-status-details-bin", "te":
		return true
	}
	return false
}
148
// isWhitelistedHeader reports whether hdr is one of the headers (":authority"
// or "user-agent") that are propagated into the metadata visible to users
// even though they are reserved.
func isWhitelistedHeader(hdr string) bool {
	return hdr == ":authority" || hdr == "user-agent"
}
159
160// contentSubtype returns the content-subtype for the given content-type.  The
161// given content-type must be a valid content-type that starts with
162// "application/grpc". A content-subtype will follow "application/grpc" after a
163// "+" or ";". See
164// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
165// more details.
166//
167// If contentType is not a valid content-type for gRPC, the boolean
168// will be false, otherwise true. If content-type == "application/grpc",
169// "application/grpc+", or "application/grpc;", the boolean will be true,
170// but no content-subtype will be returned.
171//
172// contentType is assumed to be lowercase already.
173func contentSubtype(contentType string) (string, bool) {
174	if contentType == baseContentType {
175		return "", true
176	}
177	if !strings.HasPrefix(contentType, baseContentType) {
178		return "", false
179	}
180	// guaranteed since != baseContentType and has baseContentType prefix
181	switch contentType[len(baseContentType)] {
182	case '+', ';':
183		// this will return true for "application/grpc+" or "application/grpc;"
184		// which the previous validContentType function tested to be valid, so we
185		// just say that no content-subtype is specified in this case
186		return contentType[len(baseContentType)+1:], true
187	default:
188		return "", false
189	}
190}
191
192// contentSubtype is assumed to be lowercase
193func contentType(contentSubtype string) string {
194	if contentSubtype == "" {
195		return baseContentType
196	}
197	return baseContentType + "+" + contentSubtype
198}
199
200func (d *decodeState) status() *status.Status {
201	if d.statusGen == nil {
202		// No status-details were provided; generate status using code/msg.
203		d.statusGen = status.New(codes.Code(int32(*(d.rawStatusCode))), d.rawStatusMsg)
204	}
205	return d.statusGen
206}
207
// binHdrSuffix marks metadata keys whose values are binary and therefore
// base64-encoded on the wire.
const binHdrSuffix = "-bin"

// encodeBinHeader base64-encodes v without padding.
func encodeBinHeader(v []byte) string {
	enc := base64.RawStdEncoding
	return enc.EncodeToString(v)
}

// decodeBinHeader base64-decodes v, accepting both padded and unpadded input.
func decodeBinHeader(v string) ([]byte, error) {
	enc := base64.RawStdEncoding
	if len(v)%4 == 0 {
		// A length that is a multiple of four is either padded or needs no
		// padding at all, so the padded decoder handles it.
		enc = base64.StdEncoding
	}
	return enc.DecodeString(v)
}

// encodeMetadataHeader returns the wire form of metadata value v for key k:
// base64 for binary ("-bin") keys, v unchanged otherwise.
func encodeMetadataHeader(k, v string) string {
	if !strings.HasSuffix(k, binHdrSuffix) {
		return v
	}
	return encodeBinHeader([]byte(v))
}

// decodeMetadataHeader reverses encodeMetadataHeader. For binary keys it
// returns the decoded bytes as a string together with any decoding error.
func decodeMetadataHeader(k, v string) (string, error) {
	if !strings.HasSuffix(k, binHdrSuffix) {
		return v, nil
	}
	raw, err := decodeBinHeader(v)
	return string(raw), err
}
236
237func (d *decodeState) decodeResponseHeader(frame *http2.MetaHeadersFrame) error {
238	for _, hf := range frame.Fields {
239		if err := d.processHeaderField(hf); err != nil {
240			return err
241		}
242	}
243
244	// If grpc status exists, no need to check further.
245	if d.rawStatusCode != nil || d.statusGen != nil {
246		return nil
247	}
248
249	// If grpc status doesn't exist and http status doesn't exist,
250	// then it's a malformed header.
251	if d.httpStatus == nil {
252		return streamErrorf(codes.Internal, "malformed header: doesn't contain status(gRPC or HTTP)")
253	}
254
255	if *(d.httpStatus) != http.StatusOK {
256		code, ok := httpStatusConvTab[*(d.httpStatus)]
257		if !ok {
258			code = codes.Unknown
259		}
260		return streamErrorf(code, http.StatusText(*(d.httpStatus)))
261	}
262
263	// gRPC status doesn't exist and http status is OK.
264	// Set rawStatusCode to be unknown and return nil error.
265	// So that, if the stream has ended this Unknown status
266	// will be propagated to the user.
267	// Otherwise, it will be ignored. In which case, status from
268	// a later trailer, that has StreamEnded flag set, is propagated.
269	code := int(codes.Unknown)
270	d.rawStatusCode = &code
271	return nil
272
273}
274
275func (d *decodeState) addMetadata(k, v string) {
276	if d.mdata == nil {
277		d.mdata = make(map[string][]string)
278	}
279	d.mdata[k] = append(d.mdata[k], v)
280}
281
// processHeaderField applies a single decoded header field f to d.
//
// Fields the transport understands (content-type, grpc-encoding, grpc-status,
// grpc-message, grpc-status-details-bin, grpc-timeout, :path, :status,
// grpc-tags-bin, grpc-trace-bin) are parsed into the matching decodeState
// fields. A malformed value of one of those fields yields a codes.Internal
// stream error. Any other field is appended to d.mdata as user metadata,
// unless it is a reserved header that is not whitelisted.
func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
	switch f.Name {
	case "content-type":
		// Reject anything other than "application/grpc", optionally followed
		// by "+" or ";" and a content-subtype.
		contentSubtype, validContentType := contentSubtype(f.Value)
		if !validContentType {
			return streamErrorf(codes.Internal, "transport: received the unexpected content-type %q", f.Value)
		}
		d.contentSubtype = contentSubtype
		// TODO: do we want to propagate the whole content-type in the metadata,
		// or come up with a way to just propagate the content-subtype if it was set?
		// ie {"content-type": "application/grpc+proto"} or {"content-subtype": "proto"}
		// in the metadata?
		d.addMetadata(f.Name, f.Value)
	case "grpc-encoding":
		d.encoding = f.Value
	case "grpc-status":
		code, err := strconv.Atoi(f.Value)
		if err != nil {
			return streamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err)
		}
		d.rawStatusCode = &code
	case "grpc-message":
		// The message is percent-encoded on the wire; store the decoded text.
		d.rawStatusMsg = decodeGrpcMessage(f.Value)
	case "grpc-status-details-bin":
		// base64-decode, then unmarshal as a google.rpc.Status proto. The
		// resulting status is cached in statusGen, which the status method
		// prefers over rawStatusCode/rawStatusMsg.
		v, err := decodeBinHeader(f.Value)
		if err != nil {
			return streamErrorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
		}
		s := &spb.Status{}
		if err := proto.Unmarshal(v, s); err != nil {
			return streamErrorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
		}
		d.statusGen = status.FromProto(s)
	case "grpc-timeout":
		// timeoutSet is recorded before parsing; a parse failure fails the
		// stream anyway.
		d.timeoutSet = true
		var err error
		if d.timeout, err = decodeTimeout(f.Value); err != nil {
			return streamErrorf(codes.Internal, "transport: malformed time-out: %v", err)
		}
	case ":path":
		d.method = f.Value
	case ":status":
		code, err := strconv.Atoi(f.Value)
		if err != nil {
			return streamErrorf(codes.Internal, "transport: malformed http-status: %v", err)
		}
		d.httpStatus = &code
	case "grpc-tags-bin":
		// Kept both as raw bytes for stats and, decoded, as metadata.
		v, err := decodeBinHeader(f.Value)
		if err != nil {
			return streamErrorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err)
		}
		d.statsTags = v
		d.addMetadata(f.Name, string(v))
	case "grpc-trace-bin":
		// Kept both as raw bytes for tracing and, decoded, as metadata.
		v, err := decodeBinHeader(f.Value)
		if err != nil {
			return streamErrorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err)
		}
		d.statsTrace = v
		d.addMetadata(f.Name, string(v))
	default:
		// Reserved headers are dropped unless whitelisted (:authority and
		// user-agent are).
		if isReservedHeader(f.Name) && !isWhitelistedHeader(f.Name) {
			break
		}
		// A "-bin" value that fails to decode is reported through errorf and
		// dropped; it deliberately does not fail the stream.
		v, err := decodeMetadataHeader(f.Name, f.Value)
		if err != nil {
			errorf("Failed to decode metadata header (%q, %q): %v", f.Name, f.Value, err)
			return nil
		}
		d.addMetadata(f.Name, v)
	}
	return nil
}
356
// timeoutUnit is the single-character unit suffix of a grpc-timeout value,
// e.g. the 'S' in "5S".
type timeoutUnit uint8

const (
	hour        timeoutUnit = 'H'
	minute      timeoutUnit = 'M'
	second      timeoutUnit = 'S'
	millisecond timeoutUnit = 'm'
	microsecond timeoutUnit = 'u'
	nanosecond  timeoutUnit = 'n'
)

// timeoutUnitToDuration returns the time.Duration that one unit u stands for.
// ok is false if u is not a recognized unit.
func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) {
	switch u {
	case hour:
		return time.Hour, true
	case minute:
		return time.Minute, true
	case second:
		return time.Second, true
	case millisecond:
		return time.Millisecond, true
	case microsecond:
		return time.Microsecond, true
	case nanosecond:
		return time.Nanosecond, true
	}
	return 0, false
}

// maxTimeoutValue is the largest numeric value encodeTimeout emits (8 digits).
const maxTimeoutValue int64 = 100000000 - 1

// div does integer division and round-up the result. Note that this is
// equivalent to (d+r-1)/r but has less chance to overflow.
func div(d, r time.Duration) int64 {
	if m := d % r; m > 0 {
		return int64(d/r + 1)
	}
	return int64(d / r)
}

// encodeTimeout formats t as a grpc-timeout value. It uses the finest unit
// whose (rounded-up) value does not exceed maxTimeoutValue, falling back to
// hours. Non-positive durations encode as "0n".
//
// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it.
func encodeTimeout(t time.Duration) string {
	if t <= 0 {
		return "0n"
	}
	if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
		return strconv.FormatInt(d, 10) + "n"
	}
	if d := div(t, time.Microsecond); d <= maxTimeoutValue {
		return strconv.FormatInt(d, 10) + "u"
	}
	if d := div(t, time.Millisecond); d <= maxTimeoutValue {
		return strconv.FormatInt(d, 10) + "m"
	}
	if d := div(t, time.Second); d <= maxTimeoutValue {
		return strconv.FormatInt(d, 10) + "S"
	}
	if d := div(t, time.Minute); d <= maxTimeoutValue {
		return strconv.FormatInt(d, 10) + "M"
	}
	// Note that maxTimeoutValue * time.Hour > MaxInt64.
	return strconv.FormatInt(div(t, time.Hour), 10) + "H"
}

// decodeTimeout parses a grpc-timeout value: an unsigned decimal integer
// followed by a one-character unit (H, M, S, m, u or n). Signed values are
// rejected. A value whose duration exceeds the largest time.Duration is
// clamped to that maximum instead of silently overflowing into a negative
// (already expired) duration.
func decodeTimeout(s string) (time.Duration, error) {
	size := len(s)
	if size < 2 {
		return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
	}
	unit := timeoutUnit(s[size-1])
	d, ok := timeoutUnitToDuration(unit)
	if !ok {
		return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s)
	}
	// The spec defines the value as a positive integer, so no sign is allowed.
	t, err := strconv.ParseUint(s[:size-1], 10, 64)
	if err != nil {
		return 0, err
	}
	const maxDuration = time.Duration(1<<63 - 1)
	if t > uint64(maxDuration/d) {
		return maxDuration, nil
	}
	return d * time.Duration(t), nil
}
438
// Byte bounds for grpc-message percent encoding: bytes in [spaceByte,
// tildeByte] other than percentByte pass through unescaped.
const (
	spaceByte   = ' '
	tildeByte   = '~'
	percentByte = '%'
)

// encodeGrpcMessage percent-encodes msg for the "grpc-message" header field.
// Printable ASCII other than '%' is left as is; every other byte becomes
// "%XX" (uppercase hex). Invalid UTF-8 is replaced with the Unicode
// replacement character before encoding. If no byte needs escaping, msg is
// returned unchanged.
func encodeGrpcMessage(msg string) string {
	for i := 0; i < len(msg); i++ {
		if c := msg[i]; c < spaceByte || c > tildeByte || c == percentByte {
			return encodeGrpcMessageUnchecked(msg)
		}
	}
	return msg
}

// encodeGrpcMessageUnchecked performs the escaping for encodeGrpcMessage,
// walking msg rune by rune.
func encodeGrpcMessageUnchecked(msg string) string {
	var out bytes.Buffer
	for len(msg) > 0 {
		r, width := utf8.DecodeRuneInString(msg)
		msg = msg[width:]
		// Even a one-byte step can yield utf8.RuneError, whose string form is
		// three bytes, so the rune's bytes are always iterated. Bytes of
		// multi-byte runes are always escaped.
		for _, c := range []byte(string(r)) {
			if width == 1 && c >= spaceByte && c <= tildeByte && c != percentByte {
				out.WriteByte(c)
			} else {
				fmt.Fprintf(&out, "%%%02X", c)
			}
		}
	}
	return out.String()
}

// decodeGrpcMessage undoes encodeGrpcMessage. msg is returned unchanged when
// it holds no '%' followed by at least two more bytes.
func decodeGrpcMessage(msg string) string {
	n := len(msg)
	for i := 0; i+2 < n; i++ {
		if msg[i] == percentByte {
			return decodeGrpcMessageUnchecked(msg)
		}
	}
	return msg
}

// decodeGrpcMessageUnchecked replaces every "%XX" hex escape in msg with the
// byte it denotes. A '%' that does not start a valid escape is kept literally.
func decodeGrpcMessageUnchecked(msg string) string {
	var out bytes.Buffer
	n := len(msg)
	for i := 0; i < n; i++ {
		c := msg[i]
		if c == percentByte && i+2 < n {
			if v, err := strconv.ParseUint(msg[i+1:i+3], 16, 8); err == nil {
				out.WriteByte(byte(v))
				i += 2
				continue
			}
		}
		out.WriteByte(c)
	}
	return out.String()
}
525
// bufWriter is a write buffer in front of a net.Conn. Written bytes collect in
// buf and are sent to conn once at least batchSize bytes are pending, or when
// Flush is called. The first write error is sticky: it is kept in err and
// returned by every later Write and Flush.
type bufWriter struct {
	buf       []byte
	offset    int
	batchSize int
	conn      net.Conn
	err       error

	// onFlush, if set, is called just before buffered bytes are written to
	// conn by Flush.
	onFlush func()
}

// newBufWriter returns a bufWriter over conn that flushes once batchSize bytes
// are buffered. The buffer holds twice batchSize bytes, so a single copy can
// overshoot the threshold before the flush happens. A batchSize <= 0
// disables buffering: every Write goes straight to conn.
func newBufWriter(conn net.Conn, batchSize int) *bufWriter {
	if batchSize < 0 {
		// make would panic on a negative length; treat it as "no buffering".
		batchSize = 0
	}
	return &bufWriter{
		buf:       make([]byte, batchSize*2),
		batchSize: batchSize,
		conn:      conn,
	}
}

// Write buffers b, flushing to the connection whenever at least batchSize
// bytes are pending. It returns the number of bytes accepted and the first
// error encountered. After a failed flush it stops immediately rather than
// retrying.
func (w *bufWriter) Write(b []byte) (n int, err error) {
	if w.err != nil {
		return 0, w.err
	}
	if w.batchSize == 0 {
		// Buffering is disabled. With an empty buf the loop below could never
		// make progress, so write through directly.
		n, w.err = w.conn.Write(b)
		return n, w.err
	}
	for len(b) > 0 {
		nn := copy(w.buf[w.offset:], b)
		b = b[nn:]
		w.offset += nn
		n += nn
		if w.offset >= w.batchSize {
			// A failed Flush leaves the sticky error set and the buffer
			// unchanged on subsequent calls, so continuing would spin forever.
			if ferr := w.Flush(); ferr != nil {
				return n, ferr
			}
		}
	}
	return n, nil
}

// Flush writes any buffered bytes to the connection, calling onFlush first
// when it is set. It returns the sticky error if an earlier write failed.
func (w *bufWriter) Flush() error {
	if w.err != nil {
		return w.err
	}
	if w.offset == 0 {
		return nil
	}
	if w.onFlush != nil {
		w.onFlush()
	}
	_, w.err = w.conn.Write(w.buf[:w.offset])
	w.offset = 0
	return w.err
}
574
575type framer struct {
576	writer *bufWriter
577	fr     *http2.Framer
578}
579
580func newFramer(conn net.Conn, writeBufferSize, readBufferSize int) *framer {
581	r := bufio.NewReaderSize(conn, readBufferSize)
582	w := newBufWriter(conn, writeBufferSize)
583	f := &framer{
584		writer: w,
585		fr:     http2.NewFramer(w, r),
586	}
587	// Opt-in to Frame reuse API on framer to reduce garbage.
588	// Frames aren't safe to read from after a subsequent call to ReadFrame.
589	f.fr.SetReuseFrames()
590	f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
591	return f
592}
593