1// Copyright 2017 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4// Package compiler generates sys descriptions of syscalls, types and resources
5// from textual descriptions.
6package compiler
7
8import (
9	"fmt"
10	"sort"
11	"strconv"
12	"strings"
13
14	"github.com/google/syzkaller/pkg/ast"
15	"github.com/google/syzkaller/prog"
16	"github.com/google/syzkaller/sys/targets"
17)
18
19// Overview of compilation process:
20// 1. ast.Parse on text file does tokenization and builds AST.
21//    This step catches basic syntax errors. AST contains full debug info.
22// 2. ExtractConsts on AST returns the set of constant identifiers.
23//    This step also does verification of include/incdir/define AST nodes.
24// 3. User translates constants to values.
25// 4. Compile on AST and const values does the rest of the work and returns Prog
26//    containing generated prog objects.
27// 4.1. assignSyscallNumbers: uses consts to assign syscall numbers.
28//      This step also detects unsupported syscalls and discards no longer
29//      needed AST nodes (include, define, comments, etc).
30// 4.2. patchConsts: patches Int nodes referring to consts with corresponding values.
31//      Also detects unsupported syscalls, structs, resources due to missing consts.
32// 4.3. check: does extensive semantical checks of AST.
33// 4.4. gen: generates prog objects from AST.
34
35// Prog is description compilation result.
// Compile returns it in one of two shapes: when consts are provided, the
// exported fields are filled in; when consts is nil, only fileConsts is set.
36type Prog struct {
	// Resources is the result of the compiler's genResources step.
37	Resources   []*prog.ResourceDesc
	// Syscalls is the result of the compiler's genSyscalls step.
38	Syscalls    []*prog.Syscall
	// StructDescs is the result of genStructDescs for the generated syscalls.
39	StructDescs []*prog.KeyedStruct
40	// Set of unsupported syscalls/flags.
41	Unsupported map[string]bool
42	// Returned if consts was nil.
	// In that case it holds the result of the compiler's extractConsts step.
43	fileConsts map[string]*ConstInfo
44}
45
46// Compile compiles sys description.
47func Compile(desc *ast.Description, consts map[string]uint64, target *targets.Target, eh ast.ErrorHandler) *Prog {
48	if eh == nil {
49		eh = ast.LoggingHandler
50	}
51	comp := &compiler{
52		desc:         desc.Clone(),
53		target:       target,
54		eh:           eh,
55		ptrSize:      target.PtrSize,
56		unsupported:  make(map[string]bool),
57		resources:    make(map[string]*ast.Resource),
58		typedefs:     make(map[string]*ast.TypeDef),
59		structs:      make(map[string]*ast.Struct),
60		intFlags:     make(map[string]*ast.IntFlags),
61		strFlags:     make(map[string]*ast.StrFlags),
62		used:         make(map[string]bool),
63		usedTypedefs: make(map[string]bool),
64		structDescs:  make(map[prog.StructKey]*prog.StructDesc),
65		structNodes:  make(map[*prog.StructDesc]*ast.Struct),
66		structVarlen: make(map[string]bool),
67	}
68	for name, n := range builtinTypedefs {
69		comp.typedefs[name] = n
70		comp.usedTypedefs[name] = true
71	}
72	for name, n := range builtinStrFlags {
73		comp.strFlags[name] = n
74	}
75	comp.typecheck()
76	// The subsequent, more complex, checks expect basic validity of the tree,
77	// in particular corrent number of type arguments. If there were errors,
78	// don't proceed to avoid out-of-bounds references to type arguments.
79	if comp.errors != 0 {
80		return nil
81	}
82	if consts == nil {
83		fileConsts := comp.extractConsts()
84		if comp.errors != 0 {
85			return nil
86		}
87		return &Prog{fileConsts: fileConsts}
88	}
89	if comp.target.SyscallNumbers {
90		comp.assignSyscallNumbers(consts)
91	}
92	comp.patchConsts(consts)
93	comp.check()
94	if comp.errors != 0 {
95		return nil
96	}
97	for _, w := range comp.warnings {
98		eh(w.pos, w.msg)
99	}
100	syscalls := comp.genSyscalls()
101	prg := &Prog{
102		Resources:   comp.genResources(),
103		Syscalls:    syscalls,
104		StructDescs: comp.genStructDescs(syscalls),
105		Unsupported: comp.unsupported,
106	}
107	if comp.errors != 0 {
108		return nil
109	}
110	return prg
111}
112
// compiler holds the state of a single Compile invocation.
113type compiler struct {
	// desc is a clone of the description passed to Compile.
114	desc     *ast.Description
115	target   *targets.Target
	// eh receives every error (and, after a successful check, every warning).
116	eh       ast.ErrorHandler
	// errors counts calls to comp.error; Compile aborts when it is non-zero.
117	errors   int
	// warnings accumulates messages recorded via comp.warning.
118	warnings []warn
	// ptrSize is copied from target.PtrSize.
119	ptrSize  uint64
120
	// unsupported is returned to the caller as Prog.Unsupported.
121	unsupported  map[string]bool
	// The maps below are keyed by name; typedefs and strFlags are
	// pre-populated with builtinTypedefs and builtinStrFlags by Compile.
122	resources    map[string]*ast.Resource
123	typedefs     map[string]*ast.TypeDef
124	structs      map[string]*ast.Struct
125	intFlags     map[string]*ast.IntFlags
126	strFlags     map[string]*ast.StrFlags
127	used         map[string]bool // contains used structs/resources
	// usedTypedefs has all builtin typedefs marked as used by Compile.
128	usedTypedefs map[string]bool
129
130	structDescs  map[prog.StructKey]*prog.StructDesc
131	structNodes  map[*prog.StructDesc]*ast.Struct
	// structVarlen caches results of structIsVarlen by struct name.
132	structVarlen map[string]bool
133}
134
// warn is a deferred warning: a formatted message together with its source
// position. Warnings are collected by comp.warning and handed to the error
// handler by Compile once the description passed all checks.
135type warn struct {
136	pos ast.Pos
137	msg string
138}
139
140func (comp *compiler) error(pos ast.Pos, msg string, args ...interface{}) {
141	comp.errors++
142	comp.eh(pos, fmt.Sprintf(msg, args...))
143}
144
145func (comp *compiler) warning(pos ast.Pos, msg string, args ...interface{}) {
146	comp.warnings = append(comp.warnings, warn{pos, fmt.Sprintf(msg, args...)})
147}
148
149func (comp *compiler) structIsVarlen(name string) bool {
150	if varlen, ok := comp.structVarlen[name]; ok {
151		return varlen
152	}
153	s := comp.structs[name]
154	if s.IsUnion {
155		if varlen, _ := comp.parseUnionAttrs(s); varlen {
156			comp.structVarlen[name] = true
157			return true
158		}
159	}
160	comp.structVarlen[name] = false // to not hang on recursive types
161	varlen := false
162	for _, fld := range s.Fields {
163		if comp.isVarlen(fld.Type) {
164			varlen = true
165			break
166		}
167	}
168	comp.structVarlen[name] = varlen
169	return varlen
170}
171
172func (comp *compiler) parseUnionAttrs(n *ast.Struct) (varlen bool, size uint64) {
173	size = sizeUnassigned
174	for _, attr := range n.Attrs {
175		switch attr.Ident {
176		case "varlen":
177			if len(attr.Args) != 0 {
178				comp.error(attr.Pos, "%v attribute has args", attr.Ident)
179			}
180			varlen = true
181		case "size":
182			size = comp.parseSizeAttr(attr)
183		default:
184			comp.error(attr.Pos, "unknown union %v attribute %v",
185				n.Name.Name, attr.Ident)
186		}
187	}
188	return
189}
190
191func (comp *compiler) parseStructAttrs(n *ast.Struct) (packed bool, size, align uint64) {
192	size = sizeUnassigned
193	for _, attr := range n.Attrs {
194		switch {
195		case attr.Ident == "packed":
196			if len(attr.Args) != 0 {
197				comp.error(attr.Pos, "%v attribute has args", attr.Ident)
198			}
199			packed = true
200		case attr.Ident == "align_ptr":
201			if len(attr.Args) != 0 {
202				comp.error(attr.Pos, "%v attribute has args", attr.Ident)
203			}
204			align = comp.ptrSize
205		case strings.HasPrefix(attr.Ident, "align_"):
206			if len(attr.Args) != 0 {
207				comp.error(attr.Pos, "%v attribute has args", attr.Ident)
208			}
209			a, err := strconv.ParseUint(attr.Ident[6:], 10, 64)
210			if err != nil {
211				comp.error(attr.Pos, "bad struct %v alignment %v",
212					n.Name.Name, attr.Ident[6:])
213				continue
214			}
215			if a&(a-1) != 0 || a == 0 || a > 1<<30 {
216				comp.error(attr.Pos, "bad struct %v alignment %v (must be a sane power of 2)",
217					n.Name.Name, a)
218			}
219			align = a
220		case attr.Ident == "size":
221			size = comp.parseSizeAttr(attr)
222		default:
223			comp.error(attr.Pos, "unknown struct %v attribute %v",
224				n.Name.Name, attr.Ident)
225		}
226	}
227	return
228}
229
230func (comp *compiler) parseSizeAttr(attr *ast.Type) uint64 {
231	if len(attr.Args) != 1 {
232		comp.error(attr.Pos, "%v attribute is expected to have 1 argument", attr.Ident)
233		return sizeUnassigned
234	}
235	sz := attr.Args[0]
236	if unexpected, _, ok := checkTypeKind(sz, kindInt); !ok {
237		comp.error(sz.Pos, "unexpected %v, expect int", unexpected)
238		return sizeUnassigned
239	}
240	if sz.HasColon || len(sz.Args) != 0 {
241		comp.error(sz.Pos, "size attribute has colon or args")
242		return sizeUnassigned
243	}
244	return sz.Value
245}
246
247func (comp *compiler) getTypeDesc(t *ast.Type) *typeDesc {
248	if desc := builtinTypes[t.Ident]; desc != nil {
249		return desc
250	}
251	if comp.resources[t.Ident] != nil {
252		return typeResource
253	}
254	if comp.structs[t.Ident] != nil {
255		return typeStruct
256	}
257	if comp.typedefs[t.Ident] != nil {
258		return typeTypedef
259	}
260	return nil
261}
262
263func (comp *compiler) getArgsBase(t *ast.Type, field string, dir prog.Dir, isArg bool) (
264	*typeDesc, []*ast.Type, prog.IntTypeCommon) {
265	desc := comp.getTypeDesc(t)
266	if desc == nil {
267		panic(fmt.Sprintf("no type desc for %#v", *t))
268	}
269	args, opt := removeOpt(t)
270	com := genCommon(t.Ident, field, sizeUnassigned, dir, opt != nil)
271	base := genIntCommon(com, 0, false)
272	if desc.NeedBase {
273		base.TypeSize = comp.ptrSize
274		if !isArg {
275			baseType := args[len(args)-1]
276			args = args[:len(args)-1]
277			base = typeInt.Gen(comp, baseType, nil, base).(*prog.IntType).IntTypeCommon
278		}
279	}
280	return desc, args, base
281}
282
283func (comp *compiler) foreachType(n0 ast.Node,
284	cb func(*ast.Type, *typeDesc, []*ast.Type, prog.IntTypeCommon)) {
285	switch n := n0.(type) {
286	case *ast.Call:
287		for _, arg := range n.Args {
288			comp.foreachSubType(arg.Type, true, cb)
289		}
290		if n.Ret != nil {
291			comp.foreachSubType(n.Ret, true, cb)
292		}
293	case *ast.Resource:
294		comp.foreachSubType(n.Base, false, cb)
295	case *ast.Struct:
296		for _, f := range n.Fields {
297			comp.foreachSubType(f.Type, false, cb)
298		}
299	case *ast.TypeDef:
300		if len(n.Args) == 0 {
301			comp.foreachSubType(n.Type, false, cb)
302		}
303	default:
304		panic(fmt.Sprintf("unexpected node %#v", n0))
305	}
306}
307
308func (comp *compiler) foreachSubType(t *ast.Type, isArg bool,
309	cb func(*ast.Type, *typeDesc, []*ast.Type, prog.IntTypeCommon)) {
310	desc, args, base := comp.getArgsBase(t, "", prog.DirIn, isArg)
311	cb(t, desc, args, base)
312	for i, arg := range args {
313		if desc.Args[i].Type == typeArgType {
314			comp.foreachSubType(arg, desc.Args[i].IsArg, cb)
315		}
316	}
317}
318
319func removeOpt(t *ast.Type) ([]*ast.Type, *ast.Type) {
320	args := t.Args
321	if last := len(args) - 1; last >= 0 && args[last].Ident == "opt" {
322		return args[:last], args[last]
323	}
324	return args, nil
325}
326
327func (comp *compiler) parseIntType(name string) (size uint64, bigEndian bool) {
328	be := strings.HasSuffix(name, "be")
329	if be {
330		name = name[:len(name)-len("be")]
331	}
332	size = comp.ptrSize
333	if name != "intptr" {
334		size, _ = strconv.ParseUint(name[3:], 10, 64)
335		size /= 8
336	}
337	return size, be
338}
339
// toArray returns the keys of m in sorted order, regardless of their values.
// The empty key is excluded; note that it is also deleted from m itself.
// The result is nil if there are no non-empty keys.
func toArray(m map[string]bool) []string {
	// Dropping "" up front means the loop below never sees it.
	delete(m, "")
	var keys []string
	for key := range m {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	return keys
}
351
// arrayContains reports whether v is one of the elements of a.
func arrayContains(a []string, v string) bool {
	for i := 0; i < len(a); i++ {
		if a[i] == v {
			return true
		}
	}
	return false
}
360