1// Copyright 2017 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4//go:generate go run gen.go
5
6package csource
7
8import (
9	"bytes"
10	"fmt"
11	"sort"
12	"strings"
13
14	"github.com/google/syzkaller/pkg/osutil"
15	"github.com/google/syzkaller/prog"
16	"github.com/google/syzkaller/sys/targets"
17)
18
// linux is the name of the Linux target OS.
const linux = "linux"

// Sandbox modes. defineList compares opts.Sandbox against these values to
// decide which SYZ_SANDBOX_* macro to enable.
const (
	sandboxNone      = "none"
	sandboxSetuid    = "setuid"
	sandboxNamespace = "namespace"
)
26
27func createCommonHeader(p, mmapProg *prog.Prog, replacements map[string]string, opts Options) ([]byte, error) {
28	defines := defineList(p, mmapProg, opts)
29	cmd := osutil.Command("cpp", "-nostdinc", "-undef", "-fdirectives-only", "-dDI", "-E", "-P", "-")
30	for _, def := range defines {
31		cmd.Args = append(cmd.Args, "-D"+def)
32	}
33	cmd.Stdin = strings.NewReader(commonHeader)
34	stderr := new(bytes.Buffer)
35	stdout := new(bytes.Buffer)
36	cmd.Stderr = stderr
37	cmd.Stdout = stdout
38	if err := cmd.Run(); len(stdout.Bytes()) == 0 {
39		return nil, fmt.Errorf("cpp failed: %v\n%v\n%v", err, stdout.String(), stderr.String())
40	}
41
42	src, err := removeSystemDefines(stdout.Bytes(), defines)
43	if err != nil {
44		return nil, err
45	}
46
47	for from, to := range replacements {
48		src = bytes.Replace(src, []byte("[["+from+"]]"), []byte(to), -1)
49	}
50
51	for from, to := range map[string]string{
52		"uint64": "uint64_t",
53		"uint32": "uint32_t",
54		"uint16": "uint16_t",
55		"uint8":  "uint8_t",
56	} {
57		src = bytes.Replace(src, []byte(from), []byte(to), -1)
58	}
59
60	return src, nil
61}
62
63func defineList(p, mmapProg *prog.Prog, opts Options) (defines []string) {
64	sysTarget := targets.Get(p.Target.OS, p.Target.Arch)
65	bitmasks, csums := prog.RequiredFeatures(p)
66	enabled := map[string]bool{
67		"GOOS_" + p.Target.OS:           true,
68		"GOARCH_" + p.Target.Arch:       true,
69		"SYZ_USE_BITMASKS":              bitmasks,
70		"SYZ_USE_CHECKSUMS":             csums,
71		"SYZ_SANDBOX_NONE":              opts.Sandbox == sandboxNone,
72		"SYZ_SANDBOX_SETUID":            opts.Sandbox == sandboxSetuid,
73		"SYZ_SANDBOX_NAMESPACE":         opts.Sandbox == sandboxNamespace,
74		"SYZ_THREADED":                  opts.Threaded,
75		"SYZ_COLLIDE":                   opts.Collide,
76		"SYZ_REPEAT":                    opts.Repeat,
77		"SYZ_REPEAT_TIMES":              opts.RepeatTimes > 1,
78		"SYZ_PROCS":                     opts.Procs > 1,
79		"SYZ_FAULT_INJECTION":           opts.Fault,
80		"SYZ_TUN_ENABLE":                opts.EnableTun,
81		"SYZ_ENABLE_CGROUPS":            opts.EnableCgroups,
82		"SYZ_ENABLE_NETDEV":             opts.EnableNetdev,
83		"SYZ_RESET_NET_NAMESPACE":       opts.ResetNet,
84		"SYZ_USE_TMP_DIR":               opts.UseTmpDir,
85		"SYZ_HANDLE_SEGV":               opts.HandleSegv,
86		"SYZ_REPRO":                     opts.Repro,
87		"SYZ_TRACE":                     opts.Trace,
88		"SYZ_EXECUTOR_USES_SHMEM":       sysTarget.ExecutorUsesShmem,
89		"SYZ_EXECUTOR_USES_FORK_SERVER": sysTarget.ExecutorUsesForkServer,
90	}
91	for def, ok := range enabled {
92		if ok {
93			defines = append(defines, def)
94		}
95	}
96	for _, c := range p.Calls {
97		defines = append(defines, "__NR_"+c.Meta.CallName)
98	}
99	for _, c := range mmapProg.Calls {
100		defines = append(defines, "__NR_"+c.Meta.CallName)
101	}
102	sort.Strings(defines)
103	return
104}
105
// removeSystemDefines strips from the cpp output src the #define lines for
// the standard predefined macros (__STDC__, __STDC_HOSTED__, __STDC_UTF_16__,
// __STDC_UTF_32__, __STDC_VERSION__) and for the command-line macros in
// defines.
//
// Each entry of defines is either "NAME", which is matched with value 1, or
// "NAME=VALUE". Such a line is removed only when it reads exactly
// "#define NAME VALUE\n". The __STDC_VERSION__ line is removed regardless of
// its value, including when it is the last line and has no trailing newline.
// The returned error is currently always nil.
func removeSystemDefines(src []byte, defines []string) ([]byte, error) {
	remove := map[string]string{
		"__STDC__":        "1",
		"__STDC_HOSTED__": "1",
		"__STDC_UTF_16__": "1",
		"__STDC_UTF_32__": "1",
	}
	for _, def := range defines {
		if eq := strings.IndexByte(def, '='); eq == -1 {
			remove[def] = "1"
		} else {
			remove[def[:eq]] = def[eq+1:]
		}
	}
	for def, val := range remove {
		src = bytes.Replace(src, []byte("#define "+def+" "+val+"\n"), nil, -1)
	}
	// These macros are matched by name only because their value varies,
	// e.g.: #define __STDC_VERSION__ 201112L
	for _, def := range []string{"__STDC_VERSION__"} {
		pos := bytes.Index(src, []byte("#define "+def))
		if pos == -1 {
			continue
		}
		// IndexByte on src[pos:] yields an offset relative to pos, so it
		// must be added back to pos to get the absolute end of the line.
		end := len(src)
		if nl := bytes.IndexByte(src[pos:], '\n'); nl != -1 {
			end = pos + nl + 1
		}
		// bytes.Replace only reads its arguments and builds a new slice, so
		// passing a sub-slice of src as the pattern is safe.
		src = bytes.Replace(src, src[pos:end], nil, -1)
	}
	return src, nil
}
138