1// Copyright 2015 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4package main
5
6import (
7	"flag"
8	"fmt"
9	"io/ioutil"
10	"os"
11	"runtime"
12
13	"github.com/google/syzkaller/pkg/csource"
14	"github.com/google/syzkaller/prog"
15	_ "github.com/google/syzkaller/sys"
16)
17
// Target selection, the input program, and the optional build step.
// -os and -arch default to the host the tool itself runs on.
var (
	flagOS    = flag.String("os", runtime.GOOS, "target os")
	flagArch  = flag.String("arch", runtime.GOARCH, "target arch")
	flagProg  = flag.String("prog", "", "file with program to convert (required)")
	flagBuild = flag.Bool("build", false, "also build the generated program")
)

// Execution shape of the generated program. A -repeat value other than 1
// turns repetition on (<=0 repeats forever).
var (
	flagThreaded = flag.Bool("threaded", false, "create threaded program")
	flagCollide  = flag.Bool("collide", false, "create collide program")
	flagRepeat   = flag.Int("repeat", 1, "repeat program that many times (<=0 - infinitely)")
	flagProcs    = flag.Int("procs", 1, "number of parallel processes")
	flagSandbox  = flag.String("sandbox", "", "sandbox to use (none, setuid, namespace)")
)

// Fault injection. It stays disabled while -fault_call is negative (the default).
var (
	flagFaultCall = flag.Int("fault_call", -1, "inject fault into this call (0-based)")
	flagFaultNth  = flag.Int("fault_nth", 0, "inject fault on n-th operation (0-based)")
)

// Optional environment features and diagnostics for the generated program.
// All of them are off by default.
var (
	flagEnableTun  = flag.Bool("tun", false, "set up TUN/TAP interface")
	flagUseTmpDir  = flag.Bool("tmpdir", false, "create a temporary dir and execute inside it")
	flagCgroups    = flag.Bool("cgroups", false, "enable cgroups support")
	flagNetdev     = flag.Bool("netdev", false, "setup various net devices")
	flagResetNet   = flag.Bool("resetnet", false, "reset net namespace after each test")
	flagHandleSegv = flag.Bool("segv", false, "catch and ignore SIGSEGV")
	flagTrace      = flag.Bool("trace", false, "trace syscall results")
)
38
39func main() {
40	flag.Parse()
41	if *flagProg == "" {
42		flag.PrintDefaults()
43		os.Exit(1)
44	}
45	target, err := prog.GetTarget(*flagOS, *flagArch)
46	if err != nil {
47		fmt.Fprintf(os.Stderr, "%v", err)
48		os.Exit(1)
49	}
50	data, err := ioutil.ReadFile(*flagProg)
51	if err != nil {
52		fmt.Fprintf(os.Stderr, "failed to read prog file: %v\n", err)
53		os.Exit(1)
54	}
55	p, err := target.Deserialize(data)
56	if err != nil {
57		fmt.Fprintf(os.Stderr, "failed to deserialize the program: %v\n", err)
58		os.Exit(1)
59	}
60	opts := csource.Options{
61		Threaded:      *flagThreaded,
62		Collide:       *flagCollide,
63		Repeat:        *flagRepeat != 1,
64		RepeatTimes:   *flagRepeat,
65		Procs:         *flagProcs,
66		Sandbox:       *flagSandbox,
67		Fault:         *flagFaultCall >= 0,
68		FaultCall:     *flagFaultCall,
69		FaultNth:      *flagFaultNth,
70		EnableTun:     *flagEnableTun,
71		UseTmpDir:     *flagUseTmpDir,
72		EnableCgroups: *flagCgroups,
73		EnableNetdev:  *flagNetdev,
74		ResetNet:      *flagResetNet,
75		HandleSegv:    *flagHandleSegv,
76		Repro:         false,
77		Trace:         *flagTrace,
78	}
79	src, err := csource.Write(p, opts)
80	if err != nil {
81		fmt.Fprintf(os.Stderr, "failed to generate C source: %v\n", err)
82		os.Exit(1)
83	}
84	if formatted, err := csource.Format(src); err != nil {
85		fmt.Fprintf(os.Stderr, "%v\n", err)
86	} else {
87		src = formatted
88	}
89	os.Stdout.Write(src)
90	if !*flagBuild {
91		return
92	}
93	bin, err := csource.Build(target, src)
94	if err != nil {
95		fmt.Fprintf(os.Stderr, "failed to build C source: %v\n", err)
96		os.Exit(1)
97	}
98	os.Remove(bin)
99	fmt.Fprintf(os.Stderr, "binary build OK\n")
100}
101