1// Copyright (c) 2016, Google Inc.
2//
3// Permission to use, copy, modify, and/or distribute this software for any
4// purpose with or without fee is hereby granted, provided that the above
5// copyright notice and this permission notice appear in all copies.
6//
7// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14
15package runner
16
17import (
18	"bufio"
19	"encoding/hex"
20	"errors"
21	"fmt"
22	"io"
23	"net"
24	"strconv"
25	"strings"
26	"sync"
27)
28
// flowType identifies what kind of entry a recorded flow is.
type flowType int

const (
	// readFlow marks bytes recorded by recordingConn.Read.
	readFlow flowType = iota
	// writeFlow marks bytes recorded by recordingConn.Write.
	writeFlow
	// specialFlow marks an entry added through recordingConn.LogSpecial.
	// Such entries carry a message and are never merged with a neighbour.
	specialFlow
)
36
// flow is a single entry in a recordingConn's record. flowType says whether
// the entry is read data, written data or a special entry; message is the
// label passed to LogSpecial (Read and Write record an empty message); data
// holds the recorded bytes.
type flow struct {
	flowType flowType
	message  string
	data     []byte
}
42
// recordingConn is a net.Conn that records the traffic that passes through it.
// WriteTo can be used to produce output that can later be loaded with
// parseTestData.
//
// The embedded Mutex is held by appendFlow while it updates flows. When
// isDatagram is set, every recorded read or write becomes its own flow rather
// than being merged into the preceding flow of the same type. local and peer
// are the endpoint names WriteTo prints in its headers (local is reported as
// the runner, peer as the shim).
type recordingConn struct {
	net.Conn
	sync.Mutex
	flows       []flow
	isDatagram  bool
	local, peer string
}
53
54func (r *recordingConn) appendFlow(flowType flowType, message string, data []byte) {
55	r.Lock()
56	defer r.Unlock()
57
58	if l := len(r.flows); flowType == specialFlow || r.isDatagram || l == 0 || r.flows[l-1].flowType != flowType {
59		buf := make([]byte, len(data))
60		copy(buf, data)
61		r.flows = append(r.flows, flow{flowType, message, buf})
62	} else {
63		r.flows[l-1].data = append(r.flows[l-1].data, data...)
64	}
65}
66
67func (r *recordingConn) Read(b []byte) (n int, err error) {
68	if n, err = r.Conn.Read(b); n == 0 {
69		return
70	}
71	r.appendFlow(readFlow, "", b[:n])
72	return
73}
74
75func (r *recordingConn) Write(b []byte) (n int, err error) {
76	if n, err = r.Conn.Write(b); n == 0 {
77		return
78	}
79	r.appendFlow(writeFlow, "", b[:n])
80	return
81}
82
// LogSpecial appends an entry to the record of type 'special'. The entry is
// labelled with message and holds a copy of data; it is always recorded as a
// separate flow and is never merged with an adjacent one.
func (r *recordingConn) LogSpecial(message string, data []byte) {
	r.appendFlow(specialFlow, message, data)
}
87
88// WriteTo writes hex dumps to w that contains the recorded traffic.
89func (r *recordingConn) WriteTo(w io.Writer) {
90	fmt.Fprintf(w, ">>> runner is %s, shim is %s\n", r.local, r.peer)
91	for i, flow := range r.flows {
92		switch flow.flowType {
93		case readFlow:
94			fmt.Fprintf(w, ">>> Flow %d (%s to %s)\n", i+1, r.peer, r.local)
95		case writeFlow:
96			fmt.Fprintf(w, ">>> Flow %d (%s to %s)\n", i+1, r.local, r.peer)
97		case specialFlow:
98			fmt.Fprintf(w, ">>> Flow %d %q\n", i+1, flow.message)
99		}
100
101		if flow.data != nil {
102			dumper := hex.Dumper(w)
103			dumper.Write(flow.data)
104			dumper.Close()
105		}
106	}
107}
108
109func (r *recordingConn) Transcript() []byte {
110	var ret []byte
111	for _, flow := range r.flows {
112		if flow.flowType != writeFlow {
113			continue
114		}
115		ret = append(ret, flow.data...)
116	}
117	return ret
118}
119
// parseTestData parses a transcript in the format written by
// recordingConn.WriteTo and returns the payload bytes of each flow.
//
// Lines starting with ">>> " are headers that separate flows. Every other
// line must be a line of hex dump: an offset, the hex bytes, then a '|' that
// starts the ASCII column. Only the hex bytes between the first space and the
// first '|' are decoded.
//
// A header ends the current flow only if that flow has data or an earlier
// flow has already been returned, so empty flows ahead of the first non-empty
// flow are skipped, as is an empty flow at the end of the input.
//
// On failure the returned flows are nil. An error is returned for a malformed
// line, an invalid hex byte, or any error hit while reading r (including a
// line too long for bufio.Scanner).
func parseTestData(r io.Reader) (flows [][]byte, err error) {
	var currentFlow []byte

	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		line := scanner.Text()
		// If the line starts with ">>> " then it marks the beginning
		// of a new flow.
		if strings.HasPrefix(line, ">>> ") {
			if len(currentFlow) > 0 || len(flows) > 0 {
				flows = append(flows, currentFlow)
				currentFlow = nil
			}
			continue
		}

		// Otherwise the line is a line of hex dump that looks like:
		// 00000170  fc f5 06 bf (...)  |.....X{&?......!|
		// (Some bytes have been omitted from the middle section.)

		// Drop the leading offset column.
		start := strings.IndexByte(line, ' ')
		if start < 0 {
			return nil, errors.New("invalid test data")
		}
		line = line[start:]

		// Drop the trailing ASCII column.
		end := strings.IndexByte(line, '|')
		if end < 0 {
			return nil, errors.New("invalid test data")
		}
		line = line[:end]

		for _, hexByte := range strings.Fields(line) {
			val, err := strconv.ParseUint(hexByte, 16, 8)
			if err != nil {
				return nil, errors.New("invalid hex byte in test data: " + err.Error())
			}
			currentFlow = append(currentFlow, byte(val))
		}
	}

	// Scan also returns false on a read error or an over-long line; without
	// this check the transcript would be silently truncated.
	if err := scanner.Err(); err != nil {
		return nil, err
	}

	if len(currentFlow) > 0 {
		flows = append(flows, currentFlow)
	}

	return flows, nil
}
168