1// Copyright 2017 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4package compiler
5
6import (
7	"fmt"
8	"sort"
9
10	"github.com/google/syzkaller/pkg/ast"
11	"github.com/google/syzkaller/prog"
12)
13
// sizeUnassigned is a sentinel (all bits set) marking a size that is not known.
// Struct descriptions start out with this TypeSize until genStructDescs lays
// them out, and parsed size attributes are compared against it to tell
// whether a size attribute was given at all.
const sizeUnassigned uint64 = 1<<64 - 1
15
16func (comp *compiler) genResources() []*prog.ResourceDesc {
17	var resources []*prog.ResourceDesc
18	for name, n := range comp.resources {
19		if !comp.used[name] {
20			continue
21		}
22		resources = append(resources, comp.genResource(n))
23	}
24	sort.Slice(resources, func(i, j int) bool {
25		return resources[i].Name < resources[j].Name
26	})
27	return resources
28}
29
30func (comp *compiler) genResource(n *ast.Resource) *prog.ResourceDesc {
31	res := &prog.ResourceDesc{
32		Name: n.Name.Name,
33	}
34	var base *ast.Type
35	for n != nil {
36		res.Values = append(genIntArray(n.Values), res.Values...)
37		res.Kind = append([]string{n.Name.Name}, res.Kind...)
38		base = n.Base
39		n = comp.resources[n.Base.Ident]
40	}
41	if len(res.Values) == 0 {
42		res.Values = []uint64{0}
43	}
44	res.Type = comp.genType(base, "", prog.DirIn, false)
45	return res
46}
47
48func (comp *compiler) genSyscalls() []*prog.Syscall {
49	var calls []*prog.Syscall
50	for _, decl := range comp.desc.Nodes {
51		if n, ok := decl.(*ast.Call); ok && n.NR != ^uint64(0) {
52			calls = append(calls, comp.genSyscall(n))
53		}
54	}
55	sort.Slice(calls, func(i, j int) bool {
56		return calls[i].Name < calls[j].Name
57	})
58	return calls
59}
60
61func (comp *compiler) genSyscall(n *ast.Call) *prog.Syscall {
62	var ret prog.Type
63	if n.Ret != nil {
64		ret = comp.genType(n.Ret, "ret", prog.DirOut, true)
65	}
66	return &prog.Syscall{
67		Name:     n.Name.Name,
68		CallName: n.CallName,
69		NR:       n.NR,
70		Args:     comp.genFieldArray(n.Args, prog.DirIn, true),
71		Ret:      ret,
72	}
73}
74
75func (comp *compiler) genStructDescs(syscalls []*prog.Syscall) []*prog.KeyedStruct {
76	// Calculate struct/union/array sizes, add padding to structs and detach
77	// StructDesc's from StructType's. StructType's can be recursive so it's
78	// not possible to write them out inline as other types. To break the
79	// recursion detach them, and write StructDesc's out as separate array
80	// of KeyedStruct's. prog package will reattach them during init.
81	ctx := &structGen{
82		comp:   comp,
83		padded: make(map[interface{}]bool),
84		detach: make(map[**prog.StructDesc]bool),
85	}
86	// We have to do this in the loop until we pad nothing new
87	// due to recursive structs.
88	for {
89		start := len(ctx.padded)
90		for _, c := range syscalls {
91			for _, a := range c.Args {
92				ctx.walk(a)
93			}
94			if c.Ret != nil {
95				ctx.walk(c.Ret)
96			}
97		}
98		if start == len(ctx.padded) {
99			break
100		}
101	}
102
103	// Detach StructDesc's from StructType's. prog will reattach them again.
104	for descp := range ctx.detach {
105		*descp = nil
106	}
107
108	sort.Slice(ctx.structs, func(i, j int) bool {
109		si, sj := ctx.structs[i], ctx.structs[j]
110		if si.Key.Name != sj.Key.Name {
111			return si.Key.Name < sj.Key.Name
112		}
113		return si.Key.Dir < sj.Key.Dir
114	})
115	return ctx.structs
116}
117
// structGen holds the state of the struct layout pass run by genStructDescs.
type structGen struct {
	comp    *compiler
	// padded records what has been laid out, keyed by *prog.StructDesc (set in
	// check) or by *prog.ArrayType (set in walkArray). check sets a desc entry
	// when it starts processing and removes it again if the struct has to be
	// postponed to a later pass.
	padded  map[interface{}]bool
	// detach collects the addresses of StructDesc pointer fields.
	// genStructDescs sets them all to nil once layout is done.
	detach  map[**prog.StructDesc]bool
	// structs accumulates the descriptions of used structs/unions that are
	// returned by genStructDescs.
	structs []*prog.KeyedStruct
}
124
125func (ctx *structGen) check(key prog.StructKey, descp **prog.StructDesc) bool {
126	ctx.detach[descp] = true
127	desc := *descp
128	if ctx.padded[desc] {
129		return false
130	}
131	ctx.padded[desc] = true
132	for _, f := range desc.Fields {
133		ctx.walk(f)
134		if !f.Varlen() && f.Size() == sizeUnassigned {
135			// An inner struct is not padded yet.
136			// Leave this struct for next iteration.
137			delete(ctx.padded, desc)
138			return false
139		}
140	}
141	if ctx.comp.used[key.Name] {
142		ctx.structs = append(ctx.structs, &prog.KeyedStruct{
143			Key:  key,
144			Desc: desc,
145		})
146	}
147	return true
148}
149
150func (ctx *structGen) walk(t0 prog.Type) {
151	switch t := t0.(type) {
152	case *prog.PtrType:
153		ctx.walk(t.Type)
154	case *prog.ArrayType:
155		ctx.walkArray(t)
156	case *prog.StructType:
157		ctx.walkStruct(t)
158	case *prog.UnionType:
159		ctx.walkUnion(t)
160	}
161}
162
163func (ctx *structGen) walkArray(t *prog.ArrayType) {
164	if ctx.padded[t] {
165		return
166	}
167	ctx.walk(t.Type)
168	if !t.Type.Varlen() && t.Type.Size() == sizeUnassigned {
169		// An inner struct is not padded yet.
170		// Leave this array for next iteration.
171		return
172	}
173	ctx.padded[t] = true
174	t.TypeSize = 0
175	if t.Kind == prog.ArrayRangeLen && t.RangeBegin == t.RangeEnd && !t.Type.Varlen() {
176		t.TypeSize = t.RangeBegin * t.Type.Size()
177	}
178}
179
180func (ctx *structGen) walkStruct(t *prog.StructType) {
181	if !ctx.check(t.Key, &t.StructDesc) {
182		return
183	}
184	comp := ctx.comp
185	structNode := comp.structNodes[t.StructDesc]
186	// Add paddings, calculate size, mark bitfields.
187	varlen := false
188	for _, f := range t.Fields {
189		if f.Varlen() {
190			varlen = true
191		}
192	}
193	comp.markBitfields(t.Fields)
194	packed, sizeAttr, alignAttr := comp.parseStructAttrs(structNode)
195	t.Fields = comp.addAlignment(t.Fields, varlen, packed, alignAttr)
196	t.AlignAttr = alignAttr
197	t.TypeSize = 0
198	if !varlen {
199		for _, f := range t.Fields {
200			if !f.BitfieldMiddle() {
201				t.TypeSize += f.Size()
202			}
203		}
204		if sizeAttr != sizeUnassigned {
205			if t.TypeSize > sizeAttr {
206				comp.error(structNode.Pos, "struct %v has size attribute %v"+
207					" which is less than struct size %v",
208					structNode.Name.Name, sizeAttr, t.TypeSize)
209			}
210			if pad := sizeAttr - t.TypeSize; pad != 0 {
211				t.Fields = append(t.Fields, genPad(pad))
212			}
213			t.TypeSize = sizeAttr
214		}
215	}
216}
217
218func (ctx *structGen) walkUnion(t *prog.UnionType) {
219	if !ctx.check(t.Key, &t.StructDesc) {
220		return
221	}
222	comp := ctx.comp
223	structNode := comp.structNodes[t.StructDesc]
224	varlen, sizeAttr := comp.parseUnionAttrs(structNode)
225	t.TypeSize = 0
226	if !varlen {
227		for _, fld := range t.Fields {
228			sz := fld.Size()
229			if sizeAttr != sizeUnassigned && sz > sizeAttr {
230				comp.error(structNode.Pos, "union %v has size attribute %v"+
231					" which is less than field %v size %v",
232					structNode.Name.Name, sizeAttr, fld.Name(), sz)
233			}
234			if t.TypeSize < sz {
235				t.TypeSize = sz
236			}
237		}
238		if sizeAttr != sizeUnassigned {
239			t.TypeSize = sizeAttr
240		}
241	}
242}
243
244func (comp *compiler) genStructDesc(res *prog.StructDesc, n *ast.Struct, dir prog.Dir, varlen bool) {
245	// Leave node for genStructDescs to calculate size/padding.
246	comp.structNodes[res] = n
247	common := genCommon(n.Name.Name, "", sizeUnassigned, dir, false)
248	common.IsVarlen = varlen
249	*res = prog.StructDesc{
250		TypeCommon: common,
251		Fields:     comp.genFieldArray(n.Fields, dir, false),
252	}
253}
254
255func (comp *compiler) markBitfields(fields []prog.Type) {
256	var bfOffset uint64
257	for i, f := range fields {
258		if f.BitfieldLength() == 0 {
259			continue
260		}
261		off, middle := bfOffset, true
262		bfOffset += f.BitfieldLength()
263		if i == len(fields)-1 || // Last bitfield in a group, if last field of the struct...
264			fields[i+1].BitfieldLength() == 0 || // or next field is not a bitfield...
265			f.Size() != fields[i+1].Size() || // or next field is of different size...
266			bfOffset+fields[i+1].BitfieldLength() > f.Size()*8 { // or next field does not fit into the current group.
267			middle, bfOffset = false, 0
268		}
269		setBitfieldOffset(f, off, middle)
270	}
271}
272
273func setBitfieldOffset(t0 prog.Type, offset uint64, middle bool) {
274	switch t := t0.(type) {
275	case *prog.IntType:
276		t.BitfieldOff, t.BitfieldMdl = offset, middle
277	case *prog.ConstType:
278		t.BitfieldOff, t.BitfieldMdl = offset, middle
279	case *prog.LenType:
280		t.BitfieldOff, t.BitfieldMdl = offset, middle
281	case *prog.FlagsType:
282		t.BitfieldOff, t.BitfieldMdl = offset, middle
283	case *prog.ProcType:
284		t.BitfieldOff, t.BitfieldMdl = offset, middle
285	default:
286		panic(fmt.Sprintf("type %#v can't be a bitfield", t))
287	}
288}
289
// addAlignment returns fields with padding fields (genPad) inserted.
//
// Non-packed structs get padding so that each field starts at a multiple of
// its alignment (typeAlign). When the struct is not varlen, trailing padding
// rounds its size up to a multiple of the struct alignment: alignAttr if
// non-zero, otherwise the largest field alignment.
//
// Packed structs get no padding between fields. Only when such a struct is not
// varlen and alignAttr is non-zero is a trailing pad appended, to round its
// size up to alignAttr.
//
// Bitfields in the middle of a group do not advance the offset. The field
// after one belongs to the same group, so it is not aligned separately.
func (comp *compiler) addAlignment(fields []prog.Type, varlen, packed bool, alignAttr uint64) []prog.Type {
	var newFields []prog.Type
	if packed {
		// If a struct is packed, statically sized and has explicitly set alignment,
		// add a padding at the end.
		// NOTE(review): newFields aliases fields, so the append below may write
		// into the spare capacity of fields. The visible caller (walkStruct)
		// reassigns the result, so this is harmless there.
		newFields = fields
		if !varlen && alignAttr != 0 {
			size := uint64(0)
			for _, f := range fields {
				if !f.BitfieldMiddle() {
					size += f.Size()
				}
			}
			if tail := size % alignAttr; tail != 0 {
				newFields = append(newFields, genPad(alignAttr-tail))
			}
		}
		return newFields
	}
	// align tracks the largest field alignment; off is the running byte offset.
	var align, off uint64
	for i, f := range fields {
		if i == 0 || !fields[i-1].BitfieldMiddle() {
			// f starts a new storage unit (the previous field is not a
			// bitfield in the middle of a group), so align it.
			// NOTE(review): typeAlign returns 0 for a struct/union with no
			// fields, which would make off%a panic. Presumably such types are
			// rejected earlier; confirm.
			a := comp.typeAlign(f)
			if align < a {
				align = a
			}
			// Append padding if the last field is not a bitfield or it's the last bitfield in a set.
			if off%a != 0 {
				pad := a - off%a
				off += pad
				newFields = append(newFields, genPad(pad))
			}
		}
		newFields = append(newFields, f)
		if !f.BitfieldMiddle() && (i != len(fields)-1 || !f.Varlen()) {
			// Increase offset if the current field is not a bitfield
			// or it's the last bitfield in a set, except when it's
			// the last field in a struct and has variable length.
			off += f.Size()
		}
	}
	if alignAttr != 0 {
		// An explicit align attribute overrides the computed alignment.
		align = alignAttr
	}
	if align != 0 && off%align != 0 && !varlen {
		// Trailing padding so the struct size is a multiple of its alignment.
		pad := align - off%align
		off += pad // NOTE(review): off is not read after this point.
		newFields = append(newFields, genPad(pad))
	}
	return newFields
}
341
342func (comp *compiler) typeAlign(t0 prog.Type) uint64 {
343	switch t0.(type) {
344	case *prog.IntType, *prog.ConstType, *prog.LenType, *prog.FlagsType, *prog.ProcType,
345		*prog.CsumType, *prog.PtrType, *prog.VmaType, *prog.ResourceType:
346		return t0.Size()
347	case *prog.BufferType:
348		return 1
349	}
350
351	switch t := t0.(type) {
352	case *prog.ArrayType:
353		return comp.typeAlign(t.Type)
354	case *prog.StructType:
355		packed, _, alignAttr := comp.parseStructAttrs(comp.structNodes[t.StructDesc])
356		if alignAttr != 0 {
357			return alignAttr // overrided by user attribute
358		}
359		if packed {
360			return 1
361		}
362		align := uint64(0)
363		for _, f := range t.Fields {
364			if a := comp.typeAlign(f); align < a {
365				align = a
366			}
367		}
368		return align
369	case *prog.UnionType:
370		align := uint64(0)
371		for _, f := range t.Fields {
372			if a := comp.typeAlign(f); align < a {
373				align = a
374			}
375		}
376		return align
377	default:
378		panic(fmt.Sprintf("unknown type: %#v", t))
379	}
380}
381
382func genPad(size uint64) prog.Type {
383	return &prog.ConstType{
384		IntTypeCommon: genIntCommon(genCommon("pad", "", size, prog.DirIn, false), 0, false),
385		IsPad:         true,
386	}
387}
388
389func (comp *compiler) genField(f *ast.Field, dir prog.Dir, isArg bool) prog.Type {
390	return comp.genType(f.Type, f.Name.Name, dir, isArg)
391}
392
393func (comp *compiler) genFieldArray(fields []*ast.Field, dir prog.Dir, isArg bool) []prog.Type {
394	var res []prog.Type
395	for _, f := range fields {
396		res = append(res, comp.genField(f, dir, isArg))
397	}
398	return res
399}
400
401func (comp *compiler) genType(t *ast.Type, field string, dir prog.Dir, isArg bool) prog.Type {
402	desc, args, base := comp.getArgsBase(t, field, dir, isArg)
403	if desc.Gen == nil {
404		panic(fmt.Sprintf("no gen for %v %#v", field, t))
405	}
406	base.IsVarlen = desc.Varlen != nil && desc.Varlen(comp, t, args)
407	return desc.Gen(comp, t, args, base)
408}
409
410func genCommon(name, field string, size uint64, dir prog.Dir, opt bool) prog.TypeCommon {
411	return prog.TypeCommon{
412		TypeName:   name,
413		TypeSize:   size,
414		FldName:    field,
415		ArgDir:     dir,
416		IsOptional: opt,
417	}
418}
419
420func genIntCommon(com prog.TypeCommon, bitLen uint64, bigEndian bool) prog.IntTypeCommon {
421	bf := prog.FormatNative
422	if bigEndian {
423		bf = prog.FormatBigEndian
424	}
425	return prog.IntTypeCommon{
426		TypeCommon:  com,
427		ArgFormat:   bf,
428		BitfieldLen: bitLen,
429	}
430}
431
432func genIntArray(a []*ast.Int) []uint64 {
433	r := make([]uint64, len(a))
434	for i, v := range a {
435		r[i] = v.Value
436	}
437	return r
438}
439
440func genStrArray(a []*ast.String) []string {
441	r := make([]string, len(a))
442	for i, v := range a {
443		r[i] = v.Value
444	}
445	return r
446}
447