1/*
2 *
3 * Copyright 2014 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package transport
20
21import (
22	"bytes"
23	"fmt"
24	"runtime"
25	"sync"
26
27	"golang.org/x/net/http2"
28	"golang.org/x/net/http2/hpack"
29)
30
31var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) {
32	e.SetMaxDynamicTableSizeLimit(v)
33}
34
// itemNode is a single link in an itemList.
type itemNode struct {
	it   interface{}
	next *itemNode
}

// itemList is a singly linked FIFO queue of arbitrary items. The zero value is
// an empty, ready-to-use list. It does no locking of its own; callers that
// share a list across goroutines must synchronize access themselves.
type itemList struct {
	head *itemNode
	tail *itemNode
}

// enqueue appends i to the back of the list.
func (il *itemList) enqueue(i interface{}) {
	n := &itemNode{it: i}
	if il.tail == nil {
		il.head, il.tail = n, n
		return
	}
	il.tail.next = n
	il.tail = n
}

// peek returns the first item in the list without removing it from the
// list. Like dequeue, it returns nil when the list is empty instead of
// dereferencing a nil head.
func (il *itemList) peek() interface{} {
	if il.head == nil {
		return nil
	}
	return il.head.it
}

// dequeue removes and returns the first item, or nil if the list is empty.
func (il *itemList) dequeue() interface{} {
	if il.head == nil {
		return nil
	}
	i := il.head.it
	il.head = il.head.next
	if il.head == nil {
		// The list is now empty; drop the stale tail so enqueue starts over.
		il.tail = nil
	}
	return i
}

// dequeueAll detaches every node at once and returns the head of the chain
// (nil if the list was empty), leaving the list empty. Callers walk the
// returned chain through next.
func (il *itemList) dequeueAll() *itemNode {
	h := il.head
	il.head, il.tail = nil, nil
	return h
}

// isEmpty reports whether the list holds no items.
func (il *itemList) isEmpty() bool {
	return il.head == nil
}
82
// The following types define the control items that flow through the
// transport's control buffer. They represent different control tasks,
// e.g., flow control, settings, stream resetting, etc.
86
// registerStream is used to register an incoming stream with loopy writer.
// Its handler creates an outStream for streamID in the empty state and records
// it among loopy's established streams.
type registerStream struct {
	streamID uint32
	wq       *writeQuota // Replenished by loopy as the stream's data is written out.
}

// headerFrame carries HTTP/2 headers for loopyWriter to HPACK-encode and write
// as a HEADERS frame plus any CONTINUATION frames needed. It is also used to
// register a stream on the client side.
//
//   - endStream sets END_STREAM on the HEADERS frame. On the server side,
//     headers with endStream set are trailers; the client side also passes
//     it through when originating a stream.
//   - initStream is called by loopy with the stream ID before the headers
//     are written. A true result requests a ping after the headers; a
//     non-nil error abandons the stream (ErrConnClosing also stops loopy).
//   - onWrite, if non-nil, is called before the headers are encoded.
//   - cleanup is applied after end-of-stream headers have been written.
//   - onOrphaned is called with ErrConnClosing if the control buffer is
//     finished while this frame is still queued.
type headerFrame struct {
	streamID   uint32
	hf         []hpack.HeaderField
	endStream  bool                       // Read on both client and server side.
	initStream func(uint32) (bool, error) // Used only on the client side.
	onWrite    func()
	wq         *writeQuota    // write quota for the stream created.
	cleanup    *cleanupStream // Valid on the server side.
	onOrphaned func(error)    // Valid on client-side; nil on the server side.
}

// cleanupStream asks loopyWriter to drop its state for streamID and, when rst
// is true, to write RST_STREAM with rstCode. The handler invokes onWrite
// before removing the stream. NOTE(review): idPtr is not read anywhere in
// this file.
type cleanupStream struct {
	streamID uint32
	idPtr    *uint32
	rst      bool
	rstCode  http2.ErrCode
	onWrite  func()
}
112
// dataFrame carries stream data to be written as one or more DATA frames.
// processData writes h first and then d. Each is split by the maximum frame
// size and by the available stream and connection flow-control quota.
// endStream requests END_STREAM on the frame that carries the final bytes, or
// on a single empty frame when both h and d are empty.
// NOTE(review): h looks like a per-message prefix written ahead of d —
// confirm with the producer.
type dataFrame struct {
	streamID  uint32
	endStream bool
	h         []byte
	d         []byte
	// onEachWrite, if non-nil, is called just before each DATA frame
	// carrying part of h or d is written out. It is not called when the
	// dataFrame is empty.
	onEachWrite func()
}
122
// incomingWindowUpdate is a WINDOW_UPDATE received from the peer. A streamID
// of 0 refers to the connection-level window.
type incomingWindowUpdate struct {
	streamID  uint32
	increment uint32
}

// outgoingWindowUpdate is a WINDOW_UPDATE to be written to the peer.
type outgoingWindowUpdate struct {
	streamID  uint32
	increment uint32
}

// incomingSettings holds settings received from the peer. loopyWriter applies
// them and then writes a SETTINGS acknowledgement.
type incomingSettings struct {
	ss []http2.Setting
}

// outgoingSettings holds settings to be written to the peer in a SETTINGS
// frame.
type outgoingSettings struct {
	ss []http2.Setting
}

// settingsAck is declared as a control item, but loopyWriter.handle has no
// case for it. NOTE(review): putting one in the control buffer would make loopy
// fail with an "unknown control message type" error. Acks for incoming
// settings are written by incomingSettingsHandler instead.
type settingsAck struct {
}

// incomingGoAway signals that the peer sent a GOAWAY. On the client side,
// loopyWriter starts draining when it sees one.
type incomingGoAway struct {
}

// goAway requests an outgoing GOAWAY. loopyWriter hands it to its
// ssGoAwayHandler when one is set. NOTE(review): the fields are not read in
// this file; their meaning is defined by that handler.
type goAway struct {
	code      http2.ErrCode
	debugData []byte
	headsUp   bool
	closeConn bool
}

// ping is a PING frame to write. ack marks it as an acknowledgement. For
// non-ack pings, data is first passed to the BDP estimator.
type ping struct {
	ack  bool
	data [8]byte
}

// outFlowControlSizeRequest asks loopyWriter for its current connection-level
// send quota, which is sent on resp. The send blocks loopy until it is
// received, unless resp is buffered.
type outFlowControlSizeRequest struct {
	resp chan uint32
}
162
// outStreamState tracks whether an outStream has data to write and whether it
// is eligible to sit in loopyWriter's activeStreams list.
type outStreamState int

const (
	// active: the stream has queued data and is scheduled in activeStreams.
	active outStreamState = iota
	// empty: nothing is queued; the stream is not in activeStreams.
	empty
	// waitingOnStreamQuota: data is queued but the stream's flow-control
	// quota is used up. The stream is kept out of activeStreams until a
	// window update or a larger initial window re-activates it.
	waitingOnStreamQuota
)
170
171type outStream struct {
172	id               uint32
173	state            outStreamState
174	itl              *itemList
175	bytesOutStanding int
176	wq               *writeQuota
177
178	next *outStream
179	prev *outStream
180}
181
182func (s *outStream) deleteSelf() {
183	if s.prev != nil {
184		s.prev.next = s.next
185	}
186	if s.next != nil {
187		s.next.prev = s.prev
188	}
189	s.next, s.prev = nil, nil
190}
191
192type outStreamList struct {
193	// Following are sentinel objects that mark the
194	// beginning and end of the list. They do not
195	// contain any item lists. All valid objects are
196	// inserted in between them.
197	// This is needed so that an outStream object can
198	// deleteSelf() in O(1) time without knowing which
199	// list it belongs to.
200	head *outStream
201	tail *outStream
202}
203
204func newOutStreamList() *outStreamList {
205	head, tail := new(outStream), new(outStream)
206	head.next = tail
207	tail.prev = head
208	return &outStreamList{
209		head: head,
210		tail: tail,
211	}
212}
213
214func (l *outStreamList) enqueue(s *outStream) {
215	e := l.tail.prev
216	e.next = s
217	s.prev = e
218	s.next = l.tail
219	l.tail.prev = s
220}
221
222// remove from the beginning of the list.
223func (l *outStreamList) dequeue() *outStream {
224	b := l.head.next
225	if b == l.tail {
226		return nil
227	}
228	b.deleteSelf()
229	return b
230}
231
232type controlBuffer struct {
233	ch              chan struct{}
234	done            <-chan struct{}
235	mu              sync.Mutex
236	consumerWaiting bool
237	list            *itemList
238	err             error
239}
240
241func newControlBuffer(done <-chan struct{}) *controlBuffer {
242	return &controlBuffer{
243		ch:   make(chan struct{}, 1),
244		list: &itemList{},
245		done: done,
246	}
247}
248
249func (c *controlBuffer) put(it interface{}) error {
250	_, err := c.executeAndPut(nil, it)
251	return err
252}
253
254func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it interface{}) (bool, error) {
255	var wakeUp bool
256	c.mu.Lock()
257	if c.err != nil {
258		c.mu.Unlock()
259		return false, c.err
260	}
261	if f != nil {
262		if !f(it) { // f wasn't successful
263			c.mu.Unlock()
264			return false, nil
265		}
266	}
267	if c.consumerWaiting {
268		wakeUp = true
269		c.consumerWaiting = false
270	}
271	c.list.enqueue(it)
272	c.mu.Unlock()
273	if wakeUp {
274		select {
275		case c.ch <- struct{}{}:
276		default:
277		}
278	}
279	return true, nil
280}
281
282func (c *controlBuffer) get(block bool) (interface{}, error) {
283	for {
284		c.mu.Lock()
285		if c.err != nil {
286			c.mu.Unlock()
287			return nil, c.err
288		}
289		if !c.list.isEmpty() {
290			h := c.list.dequeue()
291			c.mu.Unlock()
292			return h, nil
293		}
294		if !block {
295			c.mu.Unlock()
296			return nil, nil
297		}
298		c.consumerWaiting = true
299		c.mu.Unlock()
300		select {
301		case <-c.ch:
302		case <-c.done:
303			c.finish()
304			return nil, ErrConnClosing
305		}
306	}
307}
308
309func (c *controlBuffer) finish() {
310	c.mu.Lock()
311	if c.err != nil {
312		c.mu.Unlock()
313		return
314	}
315	c.err = ErrConnClosing
316	// There may be headers for streams in the control buffer.
317	// These streams need to be cleaned out since the transport
318	// is still not aware of these yet.
319	for head := c.list.dequeueAll(); head != nil; head = head.next {
320		hdr, ok := head.it.(*headerFrame)
321		if !ok {
322			continue
323		}
324		if hdr.onOrphaned != nil { // It will be nil on the server-side.
325			hdr.onOrphaned(ErrConnClosing)
326		}
327	}
328	c.mu.Unlock()
329}
330
// side identifies which end of the connection a loopyWriter serves. Several
// loopyWriter handlers behave differently on the client and the server.
type side int

const (
	clientSide side = iota
	serverSide
)
337
// loopyWriter drains a controlBuffer and writes the resulting frames through
// framer. It owns outgoing flow-control accounting: the connection-level
// sendQuota, and per-stream quota derived from oiws minus each stream's
// outstanding bytes. When writing data, it takes turns among the streams in
// activeStreams.
//
// draining is set by incomingGoAwayHandler on the client side, or from the
// result of ssGoAwayHandler. A draining client returns ErrConnClosing once no
// established streams remain.
type loopyWriter struct {
	side          side
	cbuf          *controlBuffer
	sendQuota     uint32                // Connection-level send quota.
	oiws          uint32                // outbound initial window size.
	estdStreams   map[uint32]*outStream // Established streams.
	activeStreams *outStreamList        // Streams that are sending data.
	framer        *framer
	hBuf          *bytes.Buffer  // The buffer for HPACK encoding.
	hEnc          *hpack.Encoder // HPACK encoder.
	bdpEst        *bdpEstimator
	draining      bool

	// Side-specific handlers
	ssGoAwayHandler func(*goAway) (bool, error)
}
354
355func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator) *loopyWriter {
356	var buf bytes.Buffer
357	l := &loopyWriter{
358		side:          s,
359		cbuf:          cbuf,
360		sendQuota:     defaultWindowSize,
361		oiws:          defaultWindowSize,
362		estdStreams:   make(map[uint32]*outStream),
363		activeStreams: newOutStreamList(),
364		framer:        fr,
365		hBuf:          &buf,
366		hEnc:          hpack.NewEncoder(&buf),
367		bdpEst:        bdpEst,
368	}
369	return l
370}
371
// minBatchSize is the threshold run compares against l.framer.writer.offset
// (presumably the number of bytes buffered but not yet flushed — confirm in
// the writer type). Below it, run yields once with runtime.Gosched before
// flushing, so that more work can be batched into the same flush.
const minBatchSize = 1000

// run should be run in a separate goroutine.
//
// Each outer iteration blocks for one control item, handles it and writes some
// data. It then drains, without blocking, every further control item and all
// writable stream data, and finally flushes the framer's writer.
//
// run returns only when an error occurs. ErrConnClosing is treated as a
// normal shutdown: it is logged at info level and converted to a nil return.
func (l *loopyWriter) run() (err error) {
	defer func() {
		if err == ErrConnClosing {
			// Don't log ErrConnClosing as error since it happens
			// 1. When the connection is closed by some other known issue.
			// 2. User closed the connection.
			// 3. A graceful close of connection.
			infof("transport: loopyWriter.run returning. %v", err)
			err = nil
		}
	}()
	for {
		// Block until at least one control item is available.
		it, err := l.cbuf.get(true)
		if err != nil {
			return err
		}
		if err = l.handle(it); err != nil {
			return err
		}
		if _, err = l.processData(); err != nil {
			return err
		}
		// Allow at most one yield per outer iteration.
		gosched := true
	hasdata:
		for {
			// Non-blocking: returns (nil, nil) when nothing is queued.
			it, err := l.cbuf.get(false)
			if err != nil {
				return err
			}
			if it != nil {
				if err = l.handle(it); err != nil {
					return err
				}
				if _, err = l.processData(); err != nil {
					return err
				}
				continue hasdata
			}
			// processData reports true when it could not write anything
			// (no connection quota or no active stream).
			isEmpty, err := l.processData()
			if err != nil {
				return err
			}
			if !isEmpty {
				continue hasdata
			}
			// Out of work. If little has been buffered, yield once so other
			// goroutines can enqueue more before flushing.
			if gosched {
				gosched = false
				if l.framer.writer.offset < minBatchSize {
					runtime.Gosched()
					continue hasdata
				}
			}
			// NOTE(review): the error returned by Flush is ignored here.
			l.framer.writer.Flush()
			break hasdata

		}
	}
}
433
434func (l *loopyWriter) outgoingWindowUpdateHandler(w *outgoingWindowUpdate) error {
435	return l.framer.fr.WriteWindowUpdate(w.streamID, w.increment)
436}
437
438func (l *loopyWriter) incomingWindowUpdateHandler(w *incomingWindowUpdate) error {
439	// Otherwise update the quota.
440	if w.streamID == 0 {
441		l.sendQuota += w.increment
442		return nil
443	}
444	// Find the stream and update it.
445	if str, ok := l.estdStreams[w.streamID]; ok {
446		str.bytesOutStanding -= int(w.increment)
447		if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota > 0 && str.state == waitingOnStreamQuota {
448			str.state = active
449			l.activeStreams.enqueue(str)
450			return nil
451		}
452	}
453	return nil
454}
455
456func (l *loopyWriter) outgoingSettingsHandler(s *outgoingSettings) error {
457	return l.framer.fr.WriteSettings(s.ss...)
458}
459
460func (l *loopyWriter) incomingSettingsHandler(s *incomingSettings) error {
461	if err := l.applySettings(s.ss); err != nil {
462		return err
463	}
464	return l.framer.fr.WriteSettingsAck()
465}
466
467func (l *loopyWriter) registerStreamHandler(h *registerStream) error {
468	str := &outStream{
469		id:    h.streamID,
470		state: empty,
471		itl:   &itemList{},
472		wq:    h.wq,
473	}
474	l.estdStreams[h.streamID] = str
475	return nil
476}
477
478func (l *loopyWriter) headerHandler(h *headerFrame) error {
479	if l.side == serverSide {
480		str, ok := l.estdStreams[h.streamID]
481		if !ok {
482			warningf("transport: loopy doesn't recognize the stream: %d", h.streamID)
483			return nil
484		}
485		// Case 1.A: Server is responding back with headers.
486		if !h.endStream {
487			return l.writeHeader(h.streamID, h.endStream, h.hf, h.onWrite)
488		}
489		// else:  Case 1.B: Server wants to close stream.
490
491		if str.state != empty { // either active or waiting on stream quota.
492			// add it str's list of items.
493			str.itl.enqueue(h)
494			return nil
495		}
496		if err := l.writeHeader(h.streamID, h.endStream, h.hf, h.onWrite); err != nil {
497			return err
498		}
499		return l.cleanupStreamHandler(h.cleanup)
500	}
501	// Case 2: Client wants to originate stream.
502	str := &outStream{
503		id:    h.streamID,
504		state: empty,
505		itl:   &itemList{},
506		wq:    h.wq,
507	}
508	str.itl.enqueue(h)
509	return l.originateStream(str)
510}
511
512func (l *loopyWriter) originateStream(str *outStream) error {
513	hdr := str.itl.dequeue().(*headerFrame)
514	sendPing, err := hdr.initStream(str.id)
515	if err != nil {
516		if err == ErrConnClosing {
517			return err
518		}
519		// Other errors(errStreamDrain) need not close transport.
520		return nil
521	}
522	if err = l.writeHeader(str.id, hdr.endStream, hdr.hf, hdr.onWrite); err != nil {
523		return err
524	}
525	l.estdStreams[str.id] = str
526	if sendPing {
527		return l.pingHandler(&ping{data: [8]byte{}})
528	}
529	return nil
530}
531
532func (l *loopyWriter) writeHeader(streamID uint32, endStream bool, hf []hpack.HeaderField, onWrite func()) error {
533	if onWrite != nil {
534		onWrite()
535	}
536	l.hBuf.Reset()
537	for _, f := range hf {
538		if err := l.hEnc.WriteField(f); err != nil {
539			warningf("transport: loopyWriter.writeHeader encountered error while encoding headers:", err)
540		}
541	}
542	var (
543		err               error
544		endHeaders, first bool
545	)
546	first = true
547	for !endHeaders {
548		size := l.hBuf.Len()
549		if size > http2MaxFrameLen {
550			size = http2MaxFrameLen
551		} else {
552			endHeaders = true
553		}
554		if first {
555			first = false
556			err = l.framer.fr.WriteHeaders(http2.HeadersFrameParam{
557				StreamID:      streamID,
558				BlockFragment: l.hBuf.Next(size),
559				EndStream:     endStream,
560				EndHeaders:    endHeaders,
561			})
562		} else {
563			err = l.framer.fr.WriteContinuation(
564				streamID,
565				endHeaders,
566				l.hBuf.Next(size),
567			)
568		}
569		if err != nil {
570			return err
571		}
572	}
573	return nil
574}
575
576func (l *loopyWriter) preprocessData(df *dataFrame) error {
577	str, ok := l.estdStreams[df.streamID]
578	if !ok {
579		return nil
580	}
581	// If we got data for a stream it means that
582	// stream was originated and the headers were sent out.
583	str.itl.enqueue(df)
584	if str.state == empty {
585		str.state = active
586		l.activeStreams.enqueue(str)
587	}
588	return nil
589}
590
591func (l *loopyWriter) pingHandler(p *ping) error {
592	if !p.ack {
593		l.bdpEst.timesnap(p.data)
594	}
595	return l.framer.fr.WritePing(p.ack, p.data)
596
597}
598
599func (l *loopyWriter) outFlowControlSizeRequestHandler(o *outFlowControlSizeRequest) error {
600	o.resp <- l.sendQuota
601	return nil
602}
603
604func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error {
605	c.onWrite()
606	if str, ok := l.estdStreams[c.streamID]; ok {
607		// On the server side it could be a trailers-only response or
608		// a RST_STREAM before stream initialization thus the stream might
609		// not be established yet.
610		delete(l.estdStreams, c.streamID)
611		str.deleteSelf()
612	}
613	if c.rst { // If RST_STREAM needs to be sent.
614		if err := l.framer.fr.WriteRSTStream(c.streamID, c.rstCode); err != nil {
615			return err
616		}
617	}
618	if l.side == clientSide && l.draining && len(l.estdStreams) == 0 {
619		return ErrConnClosing
620	}
621	return nil
622}
623
624func (l *loopyWriter) incomingGoAwayHandler(*incomingGoAway) error {
625	if l.side == clientSide {
626		l.draining = true
627		if len(l.estdStreams) == 0 {
628			return ErrConnClosing
629		}
630	}
631	return nil
632}
633
634func (l *loopyWriter) goAwayHandler(g *goAway) error {
635	// Handling of outgoing GoAway is very specific to side.
636	if l.ssGoAwayHandler != nil {
637		draining, err := l.ssGoAwayHandler(g)
638		if err != nil {
639			return err
640		}
641		l.draining = draining
642	}
643	return nil
644}
645
646func (l *loopyWriter) handle(i interface{}) error {
647	switch i := i.(type) {
648	case *incomingWindowUpdate:
649		return l.incomingWindowUpdateHandler(i)
650	case *outgoingWindowUpdate:
651		return l.outgoingWindowUpdateHandler(i)
652	case *incomingSettings:
653		return l.incomingSettingsHandler(i)
654	case *outgoingSettings:
655		return l.outgoingSettingsHandler(i)
656	case *headerFrame:
657		return l.headerHandler(i)
658	case *registerStream:
659		return l.registerStreamHandler(i)
660	case *cleanupStream:
661		return l.cleanupStreamHandler(i)
662	case *incomingGoAway:
663		return l.incomingGoAwayHandler(i)
664	case *dataFrame:
665		return l.preprocessData(i)
666	case *ping:
667		return l.pingHandler(i)
668	case *goAway:
669		return l.goAwayHandler(i)
670	case *outFlowControlSizeRequest:
671		return l.outFlowControlSizeRequestHandler(i)
672	default:
673		return fmt.Errorf("transport: unknown control message type %T", i)
674	}
675}
676
677func (l *loopyWriter) applySettings(ss []http2.Setting) error {
678	for _, s := range ss {
679		switch s.ID {
680		case http2.SettingInitialWindowSize:
681			o := l.oiws
682			l.oiws = s.Val
683			if o < l.oiws {
684				// If the new limit is greater make all depleted streams active.
685				for _, stream := range l.estdStreams {
686					if stream.state == waitingOnStreamQuota {
687						stream.state = active
688						l.activeStreams.enqueue(stream)
689					}
690				}
691			}
692		case http2.SettingHeaderTableSize:
693			updateHeaderTblSize(l.hEnc, s.Val)
694		}
695	}
696	return nil
697}
698
699func (l *loopyWriter) processData() (bool, error) {
700	if l.sendQuota == 0 {
701		return true, nil
702	}
703	str := l.activeStreams.dequeue()
704	if str == nil {
705		return true, nil
706	}
707	dataItem := str.itl.peek().(*dataFrame)
708	if len(dataItem.h) == 0 && len(dataItem.d) == 0 {
709		// Client sends out empty data frame with endStream = true
710		if err := l.framer.fr.WriteData(dataItem.streamID, dataItem.endStream, nil); err != nil {
711			return false, err
712		}
713		str.itl.dequeue()
714		if str.itl.isEmpty() {
715			str.state = empty
716		} else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers.
717			if err := l.writeHeader(trailer.streamID, trailer.endStream, trailer.hf, trailer.onWrite); err != nil {
718				return false, err
719			}
720			if err := l.cleanupStreamHandler(trailer.cleanup); err != nil {
721				return false, nil
722			}
723		} else {
724			l.activeStreams.enqueue(str)
725		}
726		return false, nil
727	}
728	var (
729		idx int
730		buf []byte
731	)
732	if len(dataItem.h) != 0 { // data header has not been written out yet.
733		buf = dataItem.h
734	} else {
735		idx = 1
736		buf = dataItem.d
737	}
738	size := http2MaxFrameLen
739	if len(buf) < size {
740		size = len(buf)
741	}
742	if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota <= 0 {
743		str.state = waitingOnStreamQuota
744		return false, nil
745	} else if strQuota < size {
746		size = strQuota
747	}
748
749	if l.sendQuota < uint32(size) {
750		size = int(l.sendQuota)
751	}
752	// Now that outgoing flow controls are checked we can replenish str's write quota
753	str.wq.replenish(size)
754	var endStream bool
755	// This last data message on this stream and all
756	// of it can be written in this go.
757	if dataItem.endStream && size == len(buf) {
758		// buf contains either data or it contains header but data is empty.
759		if idx == 1 || len(dataItem.d) == 0 {
760			endStream = true
761		}
762	}
763	if dataItem.onEachWrite != nil {
764		dataItem.onEachWrite()
765	}
766	if err := l.framer.fr.WriteData(dataItem.streamID, endStream, buf[:size]); err != nil {
767		return false, err
768	}
769	buf = buf[size:]
770	str.bytesOutStanding += size
771	l.sendQuota -= uint32(size)
772	if idx == 0 {
773		dataItem.h = buf
774	} else {
775		dataItem.d = buf
776	}
777
778	if len(dataItem.h) == 0 && len(dataItem.d) == 0 { // All the data from that message was written out.
779		str.itl.dequeue()
780	}
781	if str.itl.isEmpty() {
782		str.state = empty
783	} else if trailer, ok := str.itl.peek().(*headerFrame); ok { // The next item is trailers.
784		if err := l.writeHeader(trailer.streamID, trailer.endStream, trailer.hf, trailer.onWrite); err != nil {
785			return false, err
786		}
787		if err := l.cleanupStreamHandler(trailer.cleanup); err != nil {
788			return false, err
789		}
790	} else if int(l.oiws)-str.bytesOutStanding <= 0 { // Ran out of stream quota.
791		str.state = waitingOnStreamQuota
792	} else { // Otherwise add it back to the list of active streams.
793		l.activeStreams.enqueue(str)
794	}
795	return false, nil
796}
797