1// Copyright 2017 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4package compiler
5
6import (
7	"bufio"
8	"bytes"
9	"fmt"
10	"io/ioutil"
11	"path/filepath"
12	"sort"
13	"strconv"
14	"strings"
15
16	"github.com/google/syzkaller/pkg/ast"
17	"github.com/google/syzkaller/prog"
18	"github.com/google/syzkaller/sys/targets"
19)
20
// ConstInfo describes, for a single description file, everything needed to
// extract the numeric values of the consts that file references.
type ConstInfo struct {
	// Consts lists the names of consts referenced by the file: symbolic
	// idents used as int arguments/attributes/values, define names, and
	// syscall-number consts when the target uses them.
	Consts   []string
	// Includes lists the file's include directive values, in declaration order.
	Includes []string
	// Incdirs lists the file's incdir directive values, in declaration order.
	Incdirs  []string
	// Defines maps each define name to its value text (a C expression,
	// an identifier, or a decimal literal).
	Defines  map[string]string
}
27
28func ExtractConsts(desc *ast.Description, target *targets.Target, eh ast.ErrorHandler) map[string]*ConstInfo {
29	res := Compile(desc, nil, target, eh)
30	if res == nil {
31		return nil
32	}
33	return res.fileConsts
34}
35
36// extractConsts returns list of literal constants and other info required for const value extraction.
37func (comp *compiler) extractConsts() map[string]*ConstInfo {
38	infos := make(map[string]*constInfo)
39	for _, decl := range comp.desc.Nodes {
40		pos, _, _ := decl.Info()
41		info := getConstInfo(infos, pos)
42		switch n := decl.(type) {
43		case *ast.Include:
44			info.includeArray = append(info.includeArray, n.File.Value)
45		case *ast.Incdir:
46			info.incdirArray = append(info.incdirArray, n.Dir.Value)
47		case *ast.Define:
48			v := fmt.Sprint(n.Value.Value)
49			switch {
50			case n.Value.CExpr != "":
51				v = n.Value.CExpr
52			case n.Value.Ident != "":
53				v = n.Value.Ident
54			}
55			name := n.Name.Name
56			info.defines[name] = v
57			info.consts[name] = true
58		case *ast.Call:
59			if comp.target.SyscallNumbers && !strings.HasPrefix(n.CallName, "syz_") {
60				info.consts[comp.target.SyscallPrefix+n.CallName] = true
61			}
62		}
63	}
64
65	for _, decl := range comp.desc.Nodes {
66		switch decl.(type) {
67		case *ast.Call, *ast.Struct, *ast.Resource, *ast.TypeDef:
68			comp.foreachType(decl, func(t *ast.Type, desc *typeDesc,
69				args []*ast.Type, _ prog.IntTypeCommon) {
70				for i, arg := range args {
71					if desc.Args[i].Type.Kind == kindInt {
72						if arg.Ident != "" {
73							info := getConstInfo(infos, arg.Pos)
74							info.consts[arg.Ident] = true
75						}
76						if arg.Ident2 != "" {
77							info := getConstInfo(infos, arg.Pos2)
78							info.consts[arg.Ident2] = true
79						}
80					}
81				}
82			})
83		}
84	}
85
86	for _, decl := range comp.desc.Nodes {
87		switch n := decl.(type) {
88		case *ast.Struct:
89			for _, attr := range n.Attrs {
90				if attr.Ident == "size" {
91					info := getConstInfo(infos, attr.Pos)
92					info.consts[attr.Args[0].Ident] = true
93				}
94			}
95		}
96	}
97
98	comp.desc.Walk(ast.Recursive(func(n0 ast.Node) {
99		if n, ok := n0.(*ast.Int); ok {
100			info := getConstInfo(infos, n.Pos)
101			info.consts[n.Ident] = true
102		}
103	}))
104
105	return convertConstInfo(infos)
106}
107
// constInfo is the mutable per-file accumulator used while collecting consts;
// convertConstInfo turns it into the exported ConstInfo.
type constInfo struct {
	// consts is the set of referenced const names.
	consts       map[string]bool
	// defines maps define names to their value text.
	defines      map[string]string
	// includeArray holds include directive values in declaration order.
	includeArray []string
	// incdirArray holds incdir directive values in declaration order.
	incdirArray  []string
}
114
115func getConstInfo(infos map[string]*constInfo, pos ast.Pos) *constInfo {
116	info := infos[pos.File]
117	if info == nil {
118		info = &constInfo{
119			consts:  make(map[string]bool),
120			defines: make(map[string]string),
121		}
122		infos[pos.File] = info
123	}
124	return info
125}
126
127func convertConstInfo(infos map[string]*constInfo) map[string]*ConstInfo {
128	res := make(map[string]*ConstInfo)
129	for file, info := range infos {
130		res[file] = &ConstInfo{
131			Consts:   toArray(info.consts),
132			Includes: info.includeArray,
133			Incdirs:  info.incdirArray,
134			Defines:  info.defines,
135		}
136	}
137	return res
138}
139
140// assignSyscallNumbers assigns syscall numbers, discards unsupported syscalls.
141func (comp *compiler) assignSyscallNumbers(consts map[string]uint64) {
142	for _, decl := range comp.desc.Nodes {
143		c, ok := decl.(*ast.Call)
144		if !ok || strings.HasPrefix(c.CallName, "syz_") {
145			continue
146		}
147		str := comp.target.SyscallPrefix + c.CallName
148		nr, ok := consts[str]
149		if ok {
150			c.NR = nr
151			continue
152		}
153		c.NR = ^uint64(0) // mark as unused to not generate it
154		name := "syscall " + c.CallName
155		if !comp.unsupported[name] {
156			comp.unsupported[name] = true
157			comp.warning(c.Pos, "unsupported syscall: %v due to missing const %v",
158				c.CallName, str)
159		}
160	}
161}
162
163// patchConsts replaces all symbolic consts with their numeric values taken from consts map.
164// Updates desc and returns set of unsupported syscalls and flags.
165func (comp *compiler) patchConsts(consts map[string]uint64) {
166	for _, decl := range comp.desc.Nodes {
167		switch decl.(type) {
168		case *ast.IntFlags:
169			// Unsupported flag values are dropped.
170			n := decl.(*ast.IntFlags)
171			var values []*ast.Int
172			for _, v := range n.Values {
173				if comp.patchIntConst(&v.Value, &v.Ident, consts, nil) {
174					values = append(values, v)
175				}
176			}
177			n.Values = values
178		case *ast.Resource, *ast.Struct, *ast.Call, *ast.TypeDef:
179			// Walk whole tree and replace consts in Type's and Int's.
180			missing := ""
181			comp.foreachType(decl, func(_ *ast.Type, desc *typeDesc,
182				args []*ast.Type, _ prog.IntTypeCommon) {
183				for i, arg := range args {
184					if desc.Args[i].Type.Kind == kindInt {
185						comp.patchIntConst(&arg.Value, &arg.Ident, consts, &missing)
186						if arg.HasColon {
187							comp.patchIntConst(&arg.Value2,
188								&arg.Ident2, consts, &missing)
189						}
190					}
191				}
192			})
193			if n, ok := decl.(*ast.Resource); ok {
194				for _, v := range n.Values {
195					comp.patchIntConst(&v.Value, &v.Ident, consts, &missing)
196				}
197			}
198			if n, ok := decl.(*ast.Struct); ok {
199				for _, attr := range n.Attrs {
200					if attr.Ident == "size" {
201						sz := attr.Args[0]
202						comp.patchIntConst(&sz.Value, &sz.Ident, consts, &missing)
203					}
204				}
205			}
206			if missing == "" {
207				continue
208			}
209			// Produce a warning about unsupported syscall/resource/struct.
210			// TODO(dvyukov): we should transitively remove everything that
211			// depends on unsupported things.
212			// Potentially we still can get, say, a bad int range error
213			// due to the 0 const value.
214			pos, typ, name := decl.Info()
215			if id := typ + " " + name; !comp.unsupported[id] {
216				comp.unsupported[id] = true
217				comp.warning(pos, "unsupported %v: %v due to missing const %v",
218					typ, name, missing)
219			}
220			if c, ok := decl.(*ast.Call); ok {
221				c.NR = ^uint64(0) // mark as unused to not generate it
222			}
223		}
224	}
225}
226
227func (comp *compiler) patchIntConst(val *uint64, id *string, consts map[string]uint64, missing *string) bool {
228	if *id == "" {
229		return true
230	}
231	v, ok := consts[*id]
232	if !ok {
233		if missing != nil && *missing == "" {
234			*missing = *id
235		}
236	}
237	*val = v
238	return ok
239}
240
// SerializeConsts renders consts in the text format read by DeserializeConsts.
// The output is an "# AUTOGENERATED FILE" header followed by one line per
// name, sorted by name. Each entry of consts becomes "NAME = VALUE" (decimal),
// and each key of undeclared (regardless of its bool value) becomes a
// "# NAME is not set" comment line.
// If a name is present in both maps, its declared line is emitted first, so
// the output never depends on map iteration order.
func SerializeConsts(consts map[string]uint64, undeclared map[string]bool) []byte {
	type nameValuePair struct {
		declared bool
		name     string
		val      uint64
	}
	nv := make([]nameValuePair, 0, len(consts)+len(undeclared))
	for k, v := range consts {
		nv = append(nv, nameValuePair{true, k, v})
	}
	for k := range undeclared {
		nv = append(nv, nameValuePair{false, k, 0})
	}
	// sort.Slice is not stable: break name ties explicitly (declared first)
	// so the autogenerated file is deterministic.
	sort.Slice(nv, func(i, j int) bool {
		if nv[i].name != nv[j].name {
			return nv[i].name < nv[j].name
		}
		return nv[i].declared && !nv[j].declared
	})

	buf := new(bytes.Buffer)
	fmt.Fprintf(buf, "# AUTOGENERATED FILE\n")
	for _, x := range nv {
		if x.declared {
			fmt.Fprintf(buf, "%v = %v\n", x.name, x.val)
		} else {
			fmt.Fprintf(buf, "# %v is not set\n", x.name)
		}
	}
	return buf.Bytes()
}
269
270func DeserializeConsts(data []byte, file string, eh ast.ErrorHandler) map[string]uint64 {
271	consts := make(map[string]uint64)
272	pos := ast.Pos{
273		File: file,
274		Line: 1,
275	}
276	ok := true
277	s := bufio.NewScanner(bytes.NewReader(data))
278	for ; s.Scan(); pos.Line++ {
279		line := s.Text()
280		if line == "" || line[0] == '#' {
281			continue
282		}
283		eq := strings.IndexByte(line, '=')
284		if eq == -1 {
285			eh(pos, "expect '='")
286			ok = false
287			continue
288		}
289		name := strings.TrimSpace(line[:eq])
290		val, err := strconv.ParseUint(strings.TrimSpace(line[eq+1:]), 0, 64)
291		if err != nil {
292			eh(pos, fmt.Sprintf("failed to parse int: %v", err))
293			ok = false
294			continue
295		}
296		if _, dup := consts[name]; dup {
297			eh(pos, fmt.Sprintf("duplicate const %q", name))
298			ok = false
299			continue
300		}
301		consts[name] = val
302	}
303	if err := s.Err(); err != nil {
304		eh(pos, fmt.Sprintf("failed to parse: %v", err))
305		ok = false
306	}
307	if !ok {
308		return nil
309	}
310	return consts
311}
312
313func DeserializeConstsGlob(glob string, eh ast.ErrorHandler) map[string]uint64 {
314	if eh == nil {
315		eh = ast.LoggingHandler
316	}
317	files, err := filepath.Glob(glob)
318	if err != nil {
319		eh(ast.Pos{}, fmt.Sprintf("failed to find const files: %v", err))
320		return nil
321	}
322	if len(files) == 0 {
323		eh(ast.Pos{}, fmt.Sprintf("no const files matched by glob %q", glob))
324		return nil
325	}
326	consts := make(map[string]uint64)
327	for _, f := range files {
328		data, err := ioutil.ReadFile(f)
329		if err != nil {
330			eh(ast.Pos{}, fmt.Sprintf("failed to read const file: %v", err))
331			return nil
332		}
333		consts1 := DeserializeConsts(data, filepath.Base(f), eh)
334		if consts1 == nil {
335			consts = nil
336		}
337		if consts != nil {
338			for n, v := range consts1 {
339				if old, ok := consts[n]; ok && old != v {
340					eh(ast.Pos{}, fmt.Sprintf(
341						"different values for const %q: %v vs %v", n, v, old))
342					return nil
343				}
344				consts[n] = v
345			}
346		}
347	}
348	return consts
349}
350