1// Copyright 2016 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4package prog
5
6import (
7	"bytes"
8	"fmt"
9	"math/rand"
10	"reflect"
11	"regexp"
12	"sort"
13	"strings"
14	"testing"
15)
16
// setToArray returns the keys of the set s in ascending order.
// The result is always a non-nil slice, even when s is nil or empty,
// so it compares equal (via reflect.DeepEqual) to []string{}.
func setToArray(s map[string]struct{}) []string {
	names := make([]string, len(s))
	idx := 0
	for name := range s {
		names[idx] = name
		idx++
	}
	sort.Strings(names)
	return names
}
25
26func TestSerializeData(t *testing.T) {
27	t.Parallel()
28	r := rand.New(rand.NewSource(0))
29	for i := 0; i < 1e4; i++ {
30		data := make([]byte, r.Intn(4))
31		for i := range data {
32			data[i] = byte(r.Intn(256))
33		}
34		buf := new(bytes.Buffer)
35		serializeData(buf, data)
36		p := newParser(buf.Bytes())
37		if !p.Scan() {
38			t.Fatalf("parser does not scan")
39		}
40		data1, err := deserializeData(p)
41		if err != nil {
42			t.Fatalf("failed to deserialize %q -> %s: %v", data, buf.Bytes(), err)
43		}
44		if !bytes.Equal(data, data1) {
45			t.Fatalf("corrupted data %q -> %s -> %q", data, buf.Bytes(), data1)
46		}
47	}
48}
49
50func TestCallSet(t *testing.T) {
51	t.Parallel()
52	tests := []struct {
53		prog  string
54		ok    bool
55		calls []string
56	}{
57		{
58			"",
59			false,
60			[]string{},
61		},
62		{
63			"r0 =  (foo)",
64			false,
65			[]string{},
66		},
67		{
68			"getpid()",
69			true,
70			[]string{"getpid"},
71		},
72		{
73			"r11 =  getpid()",
74			true,
75			[]string{"getpid"},
76		},
77		{
78			"getpid()\n" +
79				"open(0x1, something that this package may not understand)\n" +
80				"getpid()\n" +
81				"#read()\n" +
82				"\n" +
83				"close$foo(&(0x0000) = {})\n",
84			true,
85			[]string{"getpid", "open", "close$foo"},
86		},
87	}
88	for i, test := range tests {
89		t.Run(fmt.Sprint(i), func(t *testing.T) {
90			calls, err := CallSet([]byte(test.prog))
91			if err != nil && test.ok {
92				t.Fatalf("parsing failed: %v", err)
93			}
94			if err == nil && !test.ok {
95				t.Fatalf("parsing did not fail")
96			}
97			callArray := setToArray(calls)
98			sort.Strings(test.calls)
99			if !reflect.DeepEqual(callArray, test.calls) {
100				t.Fatalf("got call set %+v, expect %+v", callArray, test.calls)
101			}
102		})
103	}
104}
105
106func TestCallSetRandom(t *testing.T) {
107	target, rs, iters := initTest(t)
108	for i := 0; i < iters; i++ {
109		p := target.Generate(rs, 10, nil)
110		calls0 := make(map[string]struct{})
111		for _, c := range p.Calls {
112			calls0[c.Meta.Name] = struct{}{}
113		}
114		calls1, err := CallSet(p.Serialize())
115		if err != nil {
116			t.Fatalf("CallSet failed: %v", err)
117		}
118		callArray0 := setToArray(calls0)
119		callArray1 := setToArray(calls1)
120		if !reflect.DeepEqual(callArray0, callArray1) {
121			t.Fatalf("got call set:\n%+v\nexpect:\n%+v", callArray1, callArray0)
122		}
123	}
124}
125
126func TestDeserialize(t *testing.T) {
127	target := initTargetTest(t, "test", "64")
128	tests := []struct {
129		input  string
130		output string
131		err    *regexp.Regexp
132	}{
133		{
134			input: `test$struct(&(0x7f0000000000)={0x0, {0x0}})`,
135		},
136		{
137			input:  `test$struct(&(0x7f0000000000)=0x0)`,
138			output: `test$struct(&(0x7f0000000000))`,
139		},
140		{
141			input: `test$regression1(&(0x7f0000000000)=[{"000000"}, {"0000000000"}])`,
142		},
143		{
144			input: `test$regression2(&(0x7f0000000000)=[0x1, 0x2, 0x3, 0x4, 0x5, 0x6])`,
145		},
146		{
147			input: `test$excessive_args1(0x0, 0x1, {0x1, &(0x7f0000000000)=[0x1, 0x2]})`,
148		},
149		{
150			input: `test$excessive_args2(0x0, 0x1, {0x1, &(0x7f0000000000)={0x1, 0x2}})`,
151		},
152		{
153			input: `test$excessive_args2(0x0, 0x1, {0x1, &(0x7f0000000000)=nil})`,
154		},
155		{
156			input: `test$excessive_args2(0x0, &(0x7f0000000000), 0x0)`,
157		},
158		{
159			input: `test$excessive_fields1(&(0x7f0000000000)={0x1, &(0x7f0000000000)=[{0x0}, 0x2]}, {0x1, 0x2, [0x1, 0x2]})`,
160		},
161		{
162			input:  `test$excessive_fields1(0x0)`,
163			output: `test$excessive_fields1(&(0x7f0000000000))`,
164		},
165		{
166			input:  `test$excessive_fields1(r0)`,
167			output: `test$excessive_fields1(&(0x7f0000000000))`,
168		},
169		{
170			input:  `test$excessive_args2(r1)`,
171			output: `test$excessive_args2(0x0)`,
172		},
173		{
174			input:  `test$excessive_args2({0x0, 0x1})`,
175			output: `test$excessive_args2(0x0)`,
176		},
177		{
178			input:  `test$excessive_args2([0x0], 0x0)`,
179			output: `test$excessive_args2(0x0)`,
180		},
181		{
182			input:  `test$excessive_args2(@foo)`,
183			output: `test$excessive_args2(0x0)`,
184		},
185		{
186			input:  `test$excessive_args2('foo')`,
187			output: `test$excessive_args2(0x0)`,
188		},
189		{
190			input:  `test$excessive_args2(&(0x7f0000000000)={0x0, 0x1})`,
191			output: `test$excessive_args2(0x0)`,
192		},
193		{
194			input:  `test$excessive_args2(nil)`,
195			output: `test$excessive_args2(0x0)`,
196		},
197		{
198			input:  `test$type_confusion1(&(0x7f0000000000)=@unknown)`,
199			output: `test$type_confusion1(&(0x7f0000000000))`,
200		},
201		{
202			input:  `test$type_confusion1(&(0x7f0000000000)=@unknown={0x0, 'abc'}, 0x0)`,
203			output: `test$type_confusion1(&(0x7f0000000000))`,
204		},
205		{
206			input:  `test$excessive_fields1(&(0x7f0000000000)=0x0)`,
207			output: `test$excessive_fields1(&(0x7f0000000000))`,
208		},
209	}
210	buf := make([]byte, ExecBufferSize)
211	for _, test := range tests {
212		p, err := target.Deserialize([]byte(test.input))
213		if err != nil {
214			if test.err == nil {
215				t.Fatalf("deserialization failed with\n%s\ndata:\n%s\n", err, test.input)
216			}
217			if !test.err.MatchString(err.Error()) {
218				t.Fatalf("deserialization failed with\n%s\nwhich doesn't match\n%s\ndata:\n%s",
219					err, test.err, test.input)
220			}
221			if test.output != "" {
222				t.Fatalf("both err and output are set")
223			}
224		} else {
225			if test.err != nil {
226				t.Fatalf("deserialization should have failed with:\n%s\ndata:\n%s\n",
227					test.err, test.input)
228			}
229			output := strings.TrimSpace(string(p.Serialize()))
230			if test.output != "" && test.output != output {
231				t.Fatalf("wrong serialized data:\n%s\nexpect:\n%s\n",
232					output, test.output)
233			}
234			p.SerializeForExec(buf)
235		}
236	}
237}
238
239func TestSerializeDeserialize(t *testing.T) {
240	target := initTargetTest(t, "test", "64")
241	tests := [][2]string{
242		{
243			`serialize0(&(0x7f0000408000)={"6861736800000000000000000000", "4849000000"})`,
244			`serialize0(&(0x7f0000408000)={'hash\x00', 'HI\x00'})`,
245		},
246		{
247			`serialize1(&(0x7f0000000000)="0000000000000000", 0x8)`,
248			`serialize1(&(0x7f0000000000)=""/8, 0x8)`,
249		},
250	}
251	for _, test := range tests {
252		p, err := target.Deserialize([]byte(test[0]))
253		if err != nil {
254			t.Fatal(err)
255		}
256		data := p.Serialize()
257		test[1] += "\n"
258		if string(data) != test[1] {
259			t.Fatalf("\ngot : %s\nwant: %s", data, test[1])
260		}
261	}
262}
263
264func TestSerializeDeserializeRandom(t *testing.T) {
265	testEachTargetRandom(t, func(t *testing.T, target *Target, rs rand.Source, iters int) {
266		data0 := make([]byte, ExecBufferSize)
267		data1 := make([]byte, ExecBufferSize)
268		for i := 0; i < iters; i++ {
269			p0 := target.Generate(rs, 10, nil)
270			if ok, _, _ := testSerializeDeserialize(t, p0, data0, data1); ok {
271				continue
272			}
273			p0, _ = Minimize(p0, -1, false, func(p1 *Prog, _ int) bool {
274				ok, _, _ := testSerializeDeserialize(t, p1, data0, data1)
275				return !ok
276			})
277			ok, n0, n1 := testSerializeDeserialize(t, p0, data0, data1)
278			if ok {
279				t.Fatal("flaky?")
280			}
281			t.Fatalf("was: %q\ngot: %q\nprogram:\n%s",
282				data0[:n0], data1[:n1], p0.Serialize())
283		}
284	})
285}
286
287func testSerializeDeserialize(t *testing.T, p0 *Prog, data0, data1 []byte) (bool, int, int) {
288	n0, err := p0.SerializeForExec(data0)
289	if err != nil {
290		t.Fatal(err)
291	}
292	serialized := p0.Serialize()
293	p1, err := p0.Target.Deserialize(serialized)
294	if err != nil {
295		t.Fatal(err)
296	}
297	n1, err := p1.SerializeForExec(data1)
298	if err != nil {
299		t.Fatal(err)
300	}
301	if !bytes.Equal(data0[:n0], data1[:n1]) {
302		return false, n0, n1
303	}
304	return true, 0, 0
305}
306
307func TestDeserializeComments(t *testing.T) {
308	target := initTargetTest(t, "test", "64")
309	p, err := target.Deserialize([]byte(`
310# comment1
311# comment2
312serialize0()
313serialize0()
314# comment3
315serialize0()
316# comment4
317serialize0()	#  comment5
318#comment6
319
320serialize0()
321#comment7
322`))
323	if err != nil {
324		t.Fatal(err)
325	}
326	for i, want := range []string{
327		"comment2",
328		"",
329		"comment3",
330		"comment5",
331		"",
332	} {
333		if got := p.Calls[i].Comment; got != want {
334			t.Errorf("bad call %v comment: %q, want %q", i, got, want)
335		}
336	}
337	wantComments := []string{
338		"comment1",
339		"comment4",
340		"comment6",
341		"comment7",
342	}
343	if !reflect.DeepEqual(p.Comments, wantComments) {
344		t.Errorf("bad program comments %q\nwant: %q", p.Comments, wantComments)
345	}
346}
347