1// Copyright 2017 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4package main
5
6import (
7	"bytes"
8	"fmt"
9	"os"
10	"regexp"
11	"strconv"
12	"strings"
13	"text/template"
14
15	"github.com/google/syzkaller/pkg/compiler"
16	"github.com/google/syzkaller/pkg/osutil"
17)
18
19func extract(info *compiler.ConstInfo, cc string, args []string, addSource string, declarePrintf bool) (
20	map[string]uint64, map[string]bool, error) {
21	data := &CompileData{
22		AddSource:     addSource,
23		Defines:       info.Defines,
24		Includes:      info.Includes,
25		Values:        info.Consts,
26		DeclarePrintf: declarePrintf,
27	}
28	undeclared := make(map[string]bool)
29	bin, out, err := compile(cc, args, data)
30	if err != nil {
31		// Some consts and syscall numbers are not defined on some archs.
32		// Figure out from compiler output undefined consts,
33		// and try to compile again without them.
34		valMap := make(map[string]bool)
35		for _, val := range info.Consts {
36			valMap[val] = true
37		}
38		for _, errMsg := range []string{
39			"error: ‘([a-zA-Z0-9_]+)’ undeclared",
40			"error: '([a-zA-Z0-9_]+)' undeclared",
41			"note: in expansion of macro ‘([a-zA-Z0-9_]+)’",
42			"error: use of undeclared identifier '([a-zA-Z0-9_]+)'",
43		} {
44			re := regexp.MustCompile(errMsg)
45			matches := re.FindAllSubmatch(out, -1)
46			for _, match := range matches {
47				val := string(match[1])
48				if valMap[val] {
49					undeclared[val] = true
50				}
51			}
52		}
53		data.Values = nil
54		for _, v := range info.Consts {
55			if undeclared[v] {
56				continue
57			}
58			data.Values = append(data.Values, v)
59		}
60		bin, out, err = compile(cc, args, data)
61		if err != nil {
62			return nil, nil, fmt.Errorf("failed to run compiler: %v\n%v", err, string(out))
63		}
64	}
65	defer os.Remove(bin)
66
67	out, err = osutil.Command(bin).CombinedOutput()
68	if err != nil {
69		return nil, nil, fmt.Errorf("failed to run flags binary: %v\n%v", err, string(out))
70	}
71	flagVals := strings.Split(string(out), " ")
72	if len(out) == 0 {
73		flagVals = nil
74	}
75	if len(flagVals) != len(data.Values) {
76		return nil, nil, fmt.Errorf("fetched wrong number of values %v, want != %v",
77			len(flagVals), len(data.Values))
78	}
79	res := make(map[string]uint64)
80	for i, name := range data.Values {
81		val := flagVals[i]
82		n, err := strconv.ParseUint(val, 10, 64)
83		if err != nil {
84			return nil, nil, fmt.Errorf("failed to parse value: %v (%v)", err, val)
85		}
86		res[name] = n
87	}
88	return res, undeclared, nil
89}
90
// CompileData is the data srcTemplate is rendered with to produce the C
// program that compile builds.
type CompileData struct {
	AddSource     string            // C source pasted verbatim into the program
	Defines       map[string]string // emitted as "#define name val", each guarded by #ifndef name
	Includes      []string          // emitted as "#include <...>"
	Values        []string          // names printed by the program, in this order, cast to unsigned long long
	DeclarePrintf bool              // if set, emit an explicit printf prototype
}
98
99func compile(cc string, args []string, data *CompileData) (bin string, out []byte, err error) {
100	src := new(bytes.Buffer)
101	if err := srcTemplate.Execute(src, data); err != nil {
102		return "", nil, fmt.Errorf("failed to generate source: %v", err)
103	}
104	binFile, err := osutil.TempFile("syz-extract-bin")
105	if err != nil {
106		return "", nil, err
107	}
108	args = append(args, []string{
109		"-x", "c", "-",
110		"-o", binFile,
111		"-w",
112	}...)
113	cmd := osutil.Command(cc, args...)
114	cmd.Stdin = src
115	if out, err := cmd.CombinedOutput(); err != nil {
116		os.Remove(binFile)
117		return "", out, err
118	}
119	return binFile, nil, nil
120}
121
// srcTemplateText is the C program used to read constant values. It is
// rendered with a *CompileData. Each include becomes an #include line. Each
// define becomes a #define guarded by #ifndef. AddSource is pasted verbatim.
// DeclarePrintf adds a printf prototype. main prints every entry of Values,
// cast to unsigned long long, as a decimal number, with single spaces between
// numbers and no trailing newline. __asm__ is defined to expand to nothing,
// presumably so that inline assembly in included headers cannot break the
// build — TODO confirm.
const srcTemplateText = `
#define __asm__(...)

{{range $incl := $.Includes}}
#include <{{$incl}}>
{{end}}

{{range $name, $val := $.Defines}}
#ifndef {{$name}}
#	define {{$name}} {{$val}}
#endif
{{end}}

{{.AddSource}}

{{if .DeclarePrintf}}
int printf(const char *format, ...);
{{end}}

int main() {
	int i;
	unsigned long long vals[] = {
		{{range $val := $.Values}}(unsigned long long){{$val}},
		{{end}}
	};
	for (i = 0; i < sizeof(vals)/sizeof(vals[0]); i++) {
		if (i != 0)
			printf(" ");
		printf("%llu", vals[i]);
	}
	return 0;
}
`

// srcTemplate is srcTemplateText parsed once at package initialization;
// template.Must panics there if the template is malformed.
var srcTemplate = template.Must(template.New("").Parse(srcTemplateText))
155