1package runner
2
3import (
4	"bufio"
5	"encoding/hex"
6	"errors"
7	"fmt"
8	"io"
9	"net"
10	"strconv"
11	"strings"
12	"sync"
13)
14
// flowType identifies what kind of entry a flow in a recording is.
type flowType int

// The kinds of entry a recordingConn stores.
const (
	// readFlow marks bytes returned by a Read on the wrapped connection.
	readFlow flowType = iota
	// writeFlow marks bytes accepted by a Write on the wrapped connection.
	writeFlow
	// specialFlow marks an annotated entry added through LogSpecial.
	specialFlow
)
22
23type flow struct {
24	flowType flowType
25	message  string
26	data     []byte
27}
28
29// recordingConn is a net.Conn that records the traffic that passes through it.
30// WriteTo can be used to produce output that can be later be loaded with
31// ParseTestData.
32type recordingConn struct {
33	net.Conn
34	sync.Mutex
35	flows       []flow
36	isDatagram  bool
37	local, peer string
38}
39
40func (r *recordingConn) appendFlow(flowType flowType, message string, data []byte) {
41	r.Lock()
42	defer r.Unlock()
43
44	if l := len(r.flows); flowType == specialFlow || r.isDatagram || l == 0 || r.flows[l-1].flowType != flowType {
45		buf := make([]byte, len(data))
46		copy(buf, data)
47		r.flows = append(r.flows, flow{flowType, message, buf})
48	} else {
49		r.flows[l-1].data = append(r.flows[l-1].data, data...)
50	}
51}
52
53func (r *recordingConn) Read(b []byte) (n int, err error) {
54	if n, err = r.Conn.Read(b); n == 0 {
55		return
56	}
57	r.appendFlow(readFlow, "", b[:n])
58	return
59}
60
61func (r *recordingConn) Write(b []byte) (n int, err error) {
62	if n, err = r.Conn.Write(b); n == 0 {
63		return
64	}
65	r.appendFlow(writeFlow, "", b[:n])
66	return
67}
68
69// LogSpecial appends an entry to the record of type 'special'.
70func (r *recordingConn) LogSpecial(message string, data []byte) {
71	r.appendFlow(specialFlow, message, data)
72}
73
74// WriteTo writes hex dumps to w that contains the recorded traffic.
75func (r *recordingConn) WriteTo(w io.Writer) {
76	fmt.Fprintf(w, ">>> runner is %s, shim is %s\n", r.local, r.peer)
77	for i, flow := range r.flows {
78		switch flow.flowType {
79		case readFlow:
80			fmt.Fprintf(w, ">>> Flow %d (%s to %s)\n", i+1, r.peer, r.local)
81		case writeFlow:
82			fmt.Fprintf(w, ">>> Flow %d (%s to %s)\n", i+1, r.local, r.peer)
83		case specialFlow:
84			fmt.Fprintf(w, ">>> Flow %d %q\n", i+1, flow.message)
85		}
86
87		if flow.data != nil {
88			dumper := hex.Dumper(w)
89			dumper.Write(flow.data)
90			dumper.Close()
91		}
92	}
93}
94
// parseTestData parses a transcript of hex dumps, as produced by
// recordingConn.WriteTo, and returns the bytes of each flow in order.
//
// Lines beginning with ">>> " are headers that separate flows. Until the first
// flow has been recorded, a header with no accumulated data is skipped; after
// that, every header appends the data accumulated so far, even if it is empty.
// Every other line must be a hex dump line of the form
//
//	00000170  fc f5 06 bf (...)  |.....X{&?......!|
//
// of which only the hex bytes between the first space and the first '|' are
// used.
//
// An error is returned if a non-header line lacks a space or a '|', if a hex
// byte fails to parse, or if reading from r fails. On error the returned
// flows are nil.
func parseTestData(r io.Reader) (flows [][]byte, err error) {
	var currentFlow []byte

	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		line := scanner.Text()

		// A ">>> " line marks the beginning of a new flow.
		if strings.HasPrefix(line, ">>> ") {
			if len(currentFlow) > 0 || len(flows) > 0 {
				flows = append(flows, currentFlow)
				currentFlow = nil
			}
			continue
		}

		// Drop the leading offset column.
		start := strings.IndexByte(line, ' ')
		if start < 0 {
			return nil, errors.New("invalid test data")
		}
		line = line[start:]

		// Drop the trailing ASCII column.
		end := strings.IndexByte(line, '|')
		if end < 0 {
			return nil, errors.New("invalid test data")
		}
		line = line[:end]

		for _, hexByte := range strings.Fields(line) {
			val, err := strconv.ParseUint(hexByte, 16, 8)
			if err != nil {
				return nil, fmt.Errorf("invalid hex byte in test data: %w", err)
			}
			currentFlow = append(currentFlow, byte(val))
		}
	}

	// Scan returns false on read errors and over-long lines as well as at
	// EOF, so a truncated transcript must not be reported as success.
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("reading test data: %w", err)
	}

	if len(currentFlow) > 0 {
		flows = append(flows, currentFlow)
	}

	return flows, nil
}
143