1// Copyright 2015 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4// Package csource generates [almost] equivalent C programs from syzkaller programs.
5package csource
6
7import (
8	"bytes"
9	"fmt"
10	"regexp"
11	"sort"
12	"strings"
13
14	"github.com/google/syzkaller/prog"
15	"github.com/google/syzkaller/sys/targets"
16)
17
18func Write(p *prog.Prog, opts Options) ([]byte, error) {
19	if err := opts.Check(p.Target.OS); err != nil {
20		return nil, fmt.Errorf("csource: invalid opts: %v", err)
21	}
22	ctx := &context{
23		p:         p,
24		opts:      opts,
25		target:    p.Target,
26		sysTarget: targets.Get(p.Target.OS, p.Target.Arch),
27		calls:     make(map[string]uint64),
28	}
29
30	calls, vars, err := ctx.generateProgCalls(ctx.p, opts.Trace)
31	if err != nil {
32		return nil, err
33	}
34
35	mmapProg := p.Target.GenerateUberMmapProg()
36	mmapCalls, _, err := ctx.generateProgCalls(mmapProg, false)
37	if err != nil {
38		return nil, err
39	}
40
41	for _, c := range append(mmapProg.Calls, p.Calls...) {
42		ctx.calls[c.Meta.CallName] = c.Meta.NR
43	}
44
45	varsBuf := new(bytes.Buffer)
46	if len(vars) != 0 {
47		fmt.Fprintf(varsBuf, "uint64 r[%v] = {", len(vars))
48		for i, v := range vars {
49			if i != 0 {
50				fmt.Fprintf(varsBuf, ", ")
51			}
52			fmt.Fprintf(varsBuf, "0x%x", v)
53		}
54		fmt.Fprintf(varsBuf, "};\n")
55	}
56
57	sandboxFunc := "loop();"
58	if opts.Sandbox != "" {
59		sandboxFunc = "do_sandbox_" + opts.Sandbox + "();"
60	}
61	replacements := map[string]string{
62		"PROCS":           fmt.Sprint(opts.Procs),
63		"REPEAT_TIMES":    fmt.Sprint(opts.RepeatTimes),
64		"NUM_CALLS":       fmt.Sprint(len(p.Calls)),
65		"MMAP_DATA":       strings.Join(mmapCalls, ""),
66		"SYSCALL_DEFINES": ctx.generateSyscallDefines(),
67		"SANDBOX_FUNC":    sandboxFunc,
68		"RESULTS":         varsBuf.String(),
69		"SYSCALLS":        ctx.generateSyscalls(calls, len(vars) != 0),
70	}
71	if !opts.Threaded && !opts.Repeat && opts.Sandbox == "" {
72		// This inlines syscalls right into main for the simplest case.
73		replacements["SANDBOX_FUNC"] = replacements["SYSCALLS"]
74		replacements["SYSCALLS"] = "unused"
75	}
76	result, err := createCommonHeader(p, mmapProg, replacements, opts)
77	if err != nil {
78		return nil, err
79	}
80	const header = "// autogenerated by syzkaller (https://github.com/google/syzkaller)\n\n"
81	result = append([]byte(header), result...)
82	result = ctx.postProcess(result)
83	return result, nil
84}
85
86type context struct {
87	p         *prog.Prog
88	opts      Options
89	target    *prog.Target
90	sysTarget *targets.Target
91	calls     map[string]uint64 // CallName -> NR
92}
93
94func (ctx *context) generateSyscalls(calls []string, hasVars bool) string {
95	opts := ctx.opts
96	buf := new(bytes.Buffer)
97	if !opts.Threaded && !opts.Collide {
98		if hasVars || opts.Trace {
99			fmt.Fprintf(buf, "\tlong res = 0;\n")
100		}
101		if opts.Repro {
102			fmt.Fprintf(buf, "\tif (write(1, \"executing program\\n\", sizeof(\"executing program\\n\") - 1)) {}\n")
103		}
104		if opts.Trace {
105			fmt.Fprintf(buf, "\tprintf(\"### start\\n\");\n")
106		}
107		for _, c := range calls {
108			fmt.Fprintf(buf, "%s", c)
109		}
110	} else {
111		if hasVars || opts.Trace {
112			fmt.Fprintf(buf, "\tlong res;")
113		}
114		fmt.Fprintf(buf, "\tswitch (call) {\n")
115		for i, c := range calls {
116			fmt.Fprintf(buf, "\tcase %v:\n", i)
117			fmt.Fprintf(buf, "%s", strings.Replace(c, "\t", "\t\t", -1))
118			fmt.Fprintf(buf, "\t\tbreak;\n")
119		}
120		fmt.Fprintf(buf, "\t}\n")
121	}
122	return buf.String()
123}
124
125func (ctx *context) generateSyscallDefines() string {
126	var calls []string
127	for name, nr := range ctx.calls {
128		if !ctx.sysTarget.SyscallNumbers ||
129			strings.HasPrefix(name, "syz_") || !ctx.sysTarget.NeedSyscallDefine(nr) {
130			continue
131		}
132		calls = append(calls, name)
133	}
134	sort.Strings(calls)
135	buf := new(bytes.Buffer)
136	prefix := ctx.sysTarget.SyscallPrefix
137	for _, name := range calls {
138		fmt.Fprintf(buf, "#ifndef %v%v\n", prefix, name)
139		fmt.Fprintf(buf, "#define %v%v %v\n", prefix, name, ctx.calls[name])
140		fmt.Fprintf(buf, "#endif\n")
141	}
142	if ctx.target.OS == "linux" && ctx.target.PtrSize == 4 {
143		// This is a dirty hack.
144		// On 32-bit linux mmap translated to old_mmap syscall which has a different signature.
145		// mmap2 has the right signature. syz-extract translates mmap to mmap2, do the same here.
146		fmt.Fprintf(buf, "#undef __NR_mmap\n")
147		fmt.Fprintf(buf, "#define __NR_mmap __NR_mmap2\n")
148	}
149	return buf.String()
150}
151
152func (ctx *context) generateProgCalls(p *prog.Prog, trace bool) ([]string, []uint64, error) {
153	exec := make([]byte, prog.ExecBufferSize)
154	progSize, err := p.SerializeForExec(exec)
155	if err != nil {
156		return nil, nil, fmt.Errorf("failed to serialize program: %v", err)
157	}
158	decoded, err := ctx.target.DeserializeExec(exec[:progSize])
159	if err != nil {
160		return nil, nil, err
161	}
162	calls, vars := ctx.generateCalls(decoded, trace)
163	return calls, vars, nil
164}
165
166func (ctx *context) generateCalls(p prog.ExecProg, trace bool) ([]string, []uint64) {
167	var calls []string
168	csumSeq := 0
169	for ci, call := range p.Calls {
170		w := new(bytes.Buffer)
171		// Copyin.
172		for _, copyin := range call.Copyin {
173			ctx.copyin(w, &csumSeq, copyin)
174		}
175
176		if ctx.opts.Fault && ctx.opts.FaultCall == ci {
177			fmt.Fprintf(w, "\twrite_file(\"/sys/kernel/debug/failslab/ignore-gfp-wait\", \"N\");\n")
178			fmt.Fprintf(w, "\twrite_file(\"/sys/kernel/debug/fail_futex/ignore-private\", \"N\");\n")
179			fmt.Fprintf(w, "\tinject_fault(%v);\n", ctx.opts.FaultNth)
180		}
181		// Call itself.
182		callName := call.Meta.CallName
183		resCopyout := call.Index != prog.ExecNoCopyout
184		argCopyout := len(call.Copyout) != 0
185		emitCall := ctx.opts.EnableTun ||
186			callName != "syz_emit_ethernet" &&
187				callName != "syz_extract_tcp_res"
188		// TODO: if we don't emit the call we must also not emit copyin, copyout and fault injection.
189		// However, simply skipping whole iteration breaks tests due to unused static functions.
190		if emitCall {
191			ctx.emitCall(w, call, ci, resCopyout || argCopyout, trace)
192		} else if trace {
193			fmt.Fprintf(w, "\t(void)res;\n")
194		}
195
196		// Copyout.
197		if resCopyout || argCopyout {
198			ctx.copyout(w, call, resCopyout)
199		}
200		calls = append(calls, w.String())
201	}
202	return calls, p.Vars
203}
204
205func (ctx *context) emitCall(w *bytes.Buffer, call prog.ExecCall, ci int, haveCopyout, trace bool) {
206	callName := call.Meta.CallName
207	native := ctx.sysTarget.SyscallNumbers && !strings.HasPrefix(callName, "syz_")
208	fmt.Fprintf(w, "\t")
209	if haveCopyout || trace {
210		fmt.Fprintf(w, "res = ")
211	}
212	if native {
213		fmt.Fprintf(w, "syscall(%v%v", ctx.sysTarget.SyscallPrefix, callName)
214	} else if strings.HasPrefix(callName, "syz_") {
215		fmt.Fprintf(w, "%v(", callName)
216	} else {
217		args := strings.Repeat(",long", len(call.Args))
218		if args != "" {
219			args = args[1:]
220		}
221		fmt.Fprintf(w, "((long(*)(%v))%v)(", args, callName)
222	}
223	for ai, arg := range call.Args {
224		if native || ai > 0 {
225			fmt.Fprintf(w, ", ")
226		}
227		switch arg := arg.(type) {
228		case prog.ExecArgConst:
229			if arg.Format != prog.FormatNative && arg.Format != prog.FormatBigEndian {
230				panic("sring format in syscall argument")
231			}
232			fmt.Fprintf(w, "%v", ctx.constArgToStr(arg))
233		case prog.ExecArgResult:
234			if arg.Format != prog.FormatNative && arg.Format != prog.FormatBigEndian {
235				panic("sring format in syscall argument")
236			}
237			val := ctx.resultArgToStr(arg)
238			if native && ctx.target.PtrSize == 4 {
239				// syscall accepts args as ellipsis, resources are uint64
240				// and take 2 slots without the cast, which would be wrong.
241				val = "(long)" + val
242			}
243			fmt.Fprintf(w, "%v", val)
244		default:
245			panic(fmt.Sprintf("unknown arg type: %+v", arg))
246		}
247	}
248	fmt.Fprintf(w, ");\n")
249	if trace {
250		fmt.Fprintf(w, "\tprintf(\"### call=%v errno=%%u\\n\", res == -1 ? errno : 0);\n", ci)
251	}
252}
253
254func (ctx *context) generateCsumInet(w *bytes.Buffer, addr uint64, arg prog.ExecArgCsum, csumSeq int) {
255	fmt.Fprintf(w, "\tstruct csum_inet csum_%d;\n", csumSeq)
256	fmt.Fprintf(w, "\tcsum_inet_init(&csum_%d);\n", csumSeq)
257	for i, chunk := range arg.Chunks {
258		switch chunk.Kind {
259		case prog.ExecArgCsumChunkData:
260			fmt.Fprintf(w, "\tNONFAILING(csum_inet_update(&csum_%d, (const uint8*)0x%x, %d));\n",
261				csumSeq, chunk.Value, chunk.Size)
262		case prog.ExecArgCsumChunkConst:
263			fmt.Fprintf(w, "\tuint%d csum_%d_chunk_%d = 0x%x;\n",
264				chunk.Size*8, csumSeq, i, chunk.Value)
265			fmt.Fprintf(w, "\tcsum_inet_update(&csum_%d, (const uint8*)&csum_%d_chunk_%d, %d);\n",
266				csumSeq, csumSeq, i, chunk.Size)
267		default:
268			panic(fmt.Sprintf("unknown checksum chunk kind %v", chunk.Kind))
269		}
270	}
271	fmt.Fprintf(w, "\tNONFAILING(*(uint16*)0x%x = csum_inet_digest(&csum_%d));\n",
272		addr, csumSeq)
273}
274
275func (ctx *context) copyin(w *bytes.Buffer, csumSeq *int, copyin prog.ExecCopyin) {
276	switch arg := copyin.Arg.(type) {
277	case prog.ExecArgConst:
278		if arg.BitfieldOffset == 0 && arg.BitfieldLength == 0 {
279			ctx.copyinVal(w, copyin.Addr, arg.Size, ctx.constArgToStr(arg), arg.Format)
280		} else {
281			if arg.Format != prog.FormatNative && arg.Format != prog.FormatBigEndian {
282				panic("bitfield+string format")
283			}
284			fmt.Fprintf(w, "\tNONFAILING(STORE_BY_BITMASK(uint%v, 0x%x, %v, %v, %v));\n",
285				arg.Size*8, copyin.Addr, ctx.constArgToStr(arg),
286				arg.BitfieldOffset, arg.BitfieldLength)
287		}
288	case prog.ExecArgResult:
289		ctx.copyinVal(w, copyin.Addr, arg.Size, ctx.resultArgToStr(arg), arg.Format)
290	case prog.ExecArgData:
291		fmt.Fprintf(w, "\tNONFAILING(memcpy((void*)0x%x, \"%s\", %v));\n",
292			copyin.Addr, toCString(arg.Data), len(arg.Data))
293	case prog.ExecArgCsum:
294		switch arg.Kind {
295		case prog.ExecArgCsumInet:
296			*csumSeq++
297			ctx.generateCsumInet(w, copyin.Addr, arg, *csumSeq)
298		default:
299			panic(fmt.Sprintf("unknown csum kind %v", arg.Kind))
300		}
301	default:
302		panic(fmt.Sprintf("bad argument type: %+v", arg))
303	}
304}
305
306func (ctx *context) copyinVal(w *bytes.Buffer, addr, size uint64, val string, bf prog.BinaryFormat) {
307	switch bf {
308	case prog.FormatNative, prog.FormatBigEndian:
309		fmt.Fprintf(w, "\tNONFAILING(*(uint%v*)0x%x = %v);\n", size*8, addr, val)
310	case prog.FormatStrDec:
311		if size != 20 {
312			panic("bad strdec size")
313		}
314		fmt.Fprintf(w, "\tNONFAILING(sprintf((char*)0x%x, \"%%020llu\", (long long)%v));\n", addr, val)
315	case prog.FormatStrHex:
316		if size != 18 {
317			panic("bad strdec size")
318		}
319		fmt.Fprintf(w, "\tNONFAILING(sprintf((char*)0x%x, \"0x%%016llx\", (long long)%v));\n", addr, val)
320	case prog.FormatStrOct:
321		if size != 23 {
322			panic("bad strdec size")
323		}
324		fmt.Fprintf(w, "\tNONFAILING(sprintf((char*)0x%x, \"%%023llo\", (long long)%v));\n", addr, val)
325	default:
326		panic("unknown binary format")
327	}
328}
329
330func (ctx *context) copyout(w *bytes.Buffer, call prog.ExecCall, resCopyout bool) {
331	if ctx.sysTarget.OS == "fuchsia" {
332		// On fuchsia we have real system calls that return ZX_OK on success,
333		// and libc calls that are casted to function returning long,
334		// as the result int -1 is returned as 0x00000000ffffffff rather than full -1.
335		if strings.HasPrefix(call.Meta.CallName, "zx_") {
336			fmt.Fprintf(w, "\tif (res == ZX_OK)")
337		} else {
338			fmt.Fprintf(w, "\tif ((int)res != -1)")
339		}
340	} else {
341		fmt.Fprintf(w, "\tif (res != -1)")
342	}
343	copyoutMultiple := len(call.Copyout) > 1 || resCopyout && len(call.Copyout) > 0
344	if copyoutMultiple {
345		fmt.Fprintf(w, " {")
346	}
347	fmt.Fprintf(w, "\n")
348	if resCopyout {
349		fmt.Fprintf(w, "\t\tr[%v] = res;\n", call.Index)
350	}
351	for _, copyout := range call.Copyout {
352		fmt.Fprintf(w, "\t\tNONFAILING(r[%v] = *(uint%v*)0x%x);\n",
353			copyout.Index, copyout.Size*8, copyout.Addr)
354	}
355	if copyoutMultiple {
356		fmt.Fprintf(w, "\t}\n")
357	}
358}
359
360func (ctx *context) constArgToStr(arg prog.ExecArgConst) string {
361	mask := (uint64(1) << (arg.Size * 8)) - 1
362	v := arg.Value & mask
363	val := fmt.Sprintf("%v", v)
364	if v == ^uint64(0)&mask {
365		val = "-1"
366	} else if v >= 10 {
367		val = fmt.Sprintf("0x%x", v)
368	}
369	if ctx.opts.Procs > 1 && arg.PidStride != 0 {
370		val += fmt.Sprintf(" + procid*%v", arg.PidStride)
371	}
372	if arg.Format == prog.FormatBigEndian {
373		val = fmt.Sprintf("htobe%v(%v)", arg.Size*8, val)
374	}
375	return val
376}
377
378func (ctx *context) resultArgToStr(arg prog.ExecArgResult) string {
379	res := fmt.Sprintf("r[%v]", arg.Index)
380	if arg.DivOp != 0 {
381		res = fmt.Sprintf("%v/%v", res, arg.DivOp)
382	}
383	if arg.AddOp != 0 {
384		res = fmt.Sprintf("%v+%v", res, arg.AddOp)
385	}
386	if arg.Format == prog.FormatBigEndian {
387		res = fmt.Sprintf("htobe%v(%v)", arg.Size*8, res)
388	}
389	return res
390}
391
392func (ctx *context) postProcess(result []byte) []byte {
393	// Remove NONFAILING, debug, fail, etc calls.
394	if !ctx.opts.HandleSegv {
395		result = regexp.MustCompile(`\t*NONFAILING\((.*)\);\n`).ReplaceAll(result, []byte("$1;\n"))
396	}
397	result = bytes.Replace(result, []byte("NORETURN"), nil, -1)
398	result = bytes.Replace(result, []byte("PRINTF"), nil, -1)
399	result = bytes.Replace(result, []byte("doexit("), []byte("exit("), -1)
400	result = regexp.MustCompile(`\t*debug\((.*\n)*?.*\);\n`).ReplaceAll(result, nil)
401	result = regexp.MustCompile(`\t*debug_dump_data\((.*\n)*?.*\);\n`).ReplaceAll(result, nil)
402	result = regexp.MustCompile(`\t*exitf\((.*\n)*?.*\);\n`).ReplaceAll(result, []byte("\texit(1);\n"))
403	result = regexp.MustCompile(`\t*fail\((.*\n)*?.*\);\n`).ReplaceAll(result, []byte("\texit(1);\n"))
404	result = regexp.MustCompile(`\t*error\((.*\n)*?.*\);\n`).ReplaceAll(result, []byte("\texit(1);\n"))
405
406	result = ctx.hoistIncludes(result)
407	result = ctx.removeEmptyLines(result)
408	return result
409}
410
411// hoistIncludes moves all includes to the top, removes dups and sorts.
412func (ctx *context) hoistIncludes(result []byte) []byte {
413	includesStart := bytes.Index(result, []byte("#include"))
414	if includesStart == -1 {
415		return result
416	}
417	includes := make(map[string]bool)
418	includeRe := regexp.MustCompile("#include <.*>\n")
419	for _, match := range includeRe.FindAll(result, -1) {
420		includes[string(match)] = true
421	}
422	result = includeRe.ReplaceAll(result, nil)
423	// Linux headers are broken, so we have to move all linux includes to the bottom.
424	var sorted, sortedLinux []string
425	for include := range includes {
426		if strings.Contains(include, "<linux/") {
427			sortedLinux = append(sortedLinux, include)
428		} else {
429			sorted = append(sorted, include)
430		}
431	}
432	sort.Strings(sorted)
433	sort.Strings(sortedLinux)
434	newResult := append([]byte{}, result[:includesStart]...)
435	newResult = append(newResult, strings.Join(sorted, "")...)
436	newResult = append(newResult, '\n')
437	newResult = append(newResult, strings.Join(sortedLinux, "")...)
438	newResult = append(newResult, result[includesStart:]...)
439	return newResult
440}
441
442// removeEmptyLines removes duplicate new lines.
443func (ctx *context) removeEmptyLines(result []byte) []byte {
444	for {
445		newResult := bytes.Replace(result, []byte{'\n', '\n', '\n'}, []byte{'\n', '\n'}, -1)
446		newResult = bytes.Replace(newResult, []byte{'\n', '\n', '\t'}, []byte{'\n', '\t'}, -1)
447		newResult = bytes.Replace(newResult, []byte{'\n', '\n', ' '}, []byte{'\n', ' '}, -1)
448		if len(newResult) == len(result) {
449			return result
450		}
451		result = newResult
452	}
453}
454
455func toCString(data []byte) []byte {
456	if len(data) == 0 {
457		return nil
458	}
459	readable := true
460	for i, v := range data {
461		// Allow 0 only as last byte.
462		if !isReadable(v) && (i != len(data)-1 || v != 0) {
463			readable = false
464			break
465		}
466	}
467	if !readable {
468		buf := new(bytes.Buffer)
469		for _, v := range data {
470			buf.Write([]byte{'\\', 'x', toHex(v >> 4), toHex(v << 4 >> 4)})
471		}
472		return buf.Bytes()
473	}
474	if data[len(data)-1] == 0 {
475		// Don't serialize last 0, C strings are 0-terminated anyway.
476		data = data[:len(data)-1]
477	}
478	buf := new(bytes.Buffer)
479	for _, v := range data {
480		switch v {
481		case '\t':
482			buf.Write([]byte{'\\', 't'})
483		case '\r':
484			buf.Write([]byte{'\\', 'r'})
485		case '\n':
486			buf.Write([]byte{'\\', 'n'})
487		case '\\':
488			buf.Write([]byte{'\\', '\\'})
489		case '"':
490			buf.Write([]byte{'\\', '"'})
491		default:
492			if v < 0x20 || v >= 0x7f {
493				panic("unexpected char during data serialization")
494			}
495			buf.WriteByte(v)
496		}
497	}
498	return buf.Bytes()
499}
500
// isReadable reports whether v is a byte in the range [0x20, 0x7f) or one of
// tab, carriage return and newline.
func isReadable(v byte) bool {
	switch v {
	case '\t', '\r', '\n':
		return true
	}
	return 0x20 <= v && v < 0x7f
}
504
// toHex returns the lower-case hex digit for v; callers pass values 0-15.
func toHex(v byte) byte {
	base := byte('0')
	if v >= 10 {
		base = 'a' - 10
	}
	return base + v
}
511