1// Copyright 2015 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4package prog
5
6import (
7	"fmt"
8)
9
// Prog is a program for a particular Target: an ordered list of syscalls
// plus optional comments attached to the program as a whole.
type Prog struct {
	// Target the program was built for.
	Target   *Target
	// Calls is the ordered list of calls; insertBefore and removeCall
	// edit it positionally.
	Calls    []*Call
	// Comments are comment strings attached to the whole program.
	Comments []string
}
15
// Call is a single syscall invocation inside a Prog.
type Call struct {
	// Meta describes the syscall being invoked.
	Meta    *Syscall
	// Args are the top-level arguments of the call.
	Args    []Arg
	// Ret receives the call's return value. It may be nil (removeCall checks
	// for this; MakeReturnArg returns nil for a nil return type).
	Ret     *ResultArg
	// Comment is an optional comment attached to this call.
	Comment string
}
22
// Arg is one argument value of a call. Concrete implementations in this
// file are ConstArg, PointerArg, DataArg, GroupArg, UnionArg and ResultArg;
// each pairs a value with the Type it was created for.
type Arg interface {
	// Type returns the type descriptor the argument was created with.
	Type() Type
	// Size returns the size of the argument's value.
	Size() uint64

	// validate checks the argument against ctx and returns an error if it is
	// inconsistent.
	validate(ctx *validCtx) error
	// serialize emits the argument into ctx.
	serialize(ctx *serializer)
}
30
31type ArgCommon struct {
32	typ Type
33}
34
35func (arg *ArgCommon) Type() Type {
36	return arg.typ
37}
38
39// Used for ConstType, IntType, FlagsType, LenType, ProcType and CsumType.
40type ConstArg struct {
41	ArgCommon
42	Val uint64
43}
44
45func MakeConstArg(t Type, v uint64) *ConstArg {
46	return &ConstArg{ArgCommon: ArgCommon{typ: t}, Val: v}
47}
48
49func (arg *ConstArg) Size() uint64 {
50	return arg.typ.Size()
51}
52
53// Value returns value, pid stride and endianness.
54func (arg *ConstArg) Value() (uint64, uint64) {
55	switch typ := (*arg).Type().(type) {
56	case *IntType:
57		return arg.Val, 0
58	case *ConstType:
59		return arg.Val, 0
60	case *FlagsType:
61		return arg.Val, 0
62	case *LenType:
63		return arg.Val, 0
64	case *ResourceType:
65		return arg.Val, 0
66	case *CsumType:
67		// Checksums are computed dynamically in executor.
68		return 0, 0
69	case *ProcType:
70		if arg.Val == procDefaultValue {
71			return 0, 0
72		}
73		return typ.ValuesStart + arg.Val, typ.ValuesPerProc
74	default:
75		panic(fmt.Sprintf("unknown ConstArg type %#v", typ))
76	}
77}
78
79// Used for PtrType and VmaType.
80type PointerArg struct {
81	ArgCommon
82	Address uint64
83	VmaSize uint64 // size of the referenced region for vma args
84	Res     Arg    // pointee (nil for vma)
85}
86
87func MakePointerArg(t Type, addr uint64, data Arg) *PointerArg {
88	if data == nil {
89		panic("nil pointer data arg")
90	}
91	return &PointerArg{
92		ArgCommon: ArgCommon{typ: t},
93		Address:   addr,
94		Res:       data,
95	}
96}
97
98func MakeVmaPointerArg(t Type, addr, size uint64) *PointerArg {
99	if addr%1024 != 0 {
100		panic("unaligned vma address")
101	}
102	return &PointerArg{
103		ArgCommon: ArgCommon{typ: t},
104		Address:   addr,
105		VmaSize:   size,
106	}
107}
108
109func MakeNullPointerArg(t Type) *PointerArg {
110	return &PointerArg{
111		ArgCommon: ArgCommon{typ: t},
112	}
113}
114
115func (arg *PointerArg) Size() uint64 {
116	return arg.typ.Size()
117}
118
119func (arg *PointerArg) IsNull() bool {
120	return arg.Address == 0 && arg.VmaSize == 0 && arg.Res == nil
121}
122
123// Used for BufferType.
124type DataArg struct {
125	ArgCommon
126	data []byte // for in/inout args
127	size uint64 // for out Args
128}
129
130func MakeDataArg(t Type, data []byte) *DataArg {
131	if t.Dir() == DirOut {
132		panic("non-empty output data arg")
133	}
134	return &DataArg{ArgCommon: ArgCommon{typ: t}, data: append([]byte{}, data...)}
135}
136
137func MakeOutDataArg(t Type, size uint64) *DataArg {
138	if t.Dir() != DirOut {
139		panic("empty input data arg")
140	}
141	return &DataArg{ArgCommon: ArgCommon{typ: t}, size: size}
142}
143
144func (arg *DataArg) Size() uint64 {
145	if len(arg.data) != 0 {
146		return uint64(len(arg.data))
147	}
148	return arg.size
149}
150
151func (arg *DataArg) Data() []byte {
152	if arg.Type().Dir() == DirOut {
153		panic("getting data of output data arg")
154	}
155	return arg.data
156}
157
158// Used for StructType and ArrayType.
159// Logical group of args (struct or array).
160type GroupArg struct {
161	ArgCommon
162	Inner []Arg
163}
164
165func MakeGroupArg(t Type, inner []Arg) *GroupArg {
166	return &GroupArg{ArgCommon: ArgCommon{typ: t}, Inner: inner}
167}
168
169func (arg *GroupArg) Size() uint64 {
170	typ0 := arg.Type()
171	if !typ0.Varlen() {
172		return typ0.Size()
173	}
174	switch typ := typ0.(type) {
175	case *StructType:
176		var size uint64
177		for _, fld := range arg.Inner {
178			if !fld.Type().BitfieldMiddle() {
179				size += fld.Size()
180			}
181		}
182		if typ.AlignAttr != 0 && size%typ.AlignAttr != 0 {
183			size += typ.AlignAttr - size%typ.AlignAttr
184		}
185		return size
186	case *ArrayType:
187		var size uint64
188		for _, elem := range arg.Inner {
189			size += elem.Size()
190		}
191		return size
192	default:
193		panic(fmt.Sprintf("bad group arg type %v", typ))
194	}
195}
196
197func (arg *GroupArg) fixedInnerSize() bool {
198	switch typ := arg.Type().(type) {
199	case *StructType:
200		return true
201	case *ArrayType:
202		return typ.Kind == ArrayRangeLen && typ.RangeBegin == typ.RangeEnd
203	default:
204		panic(fmt.Sprintf("bad group arg type %v", typ))
205	}
206}
207
208// Used for UnionType.
209type UnionArg struct {
210	ArgCommon
211	Option Arg
212}
213
214func MakeUnionArg(t Type, opt Arg) *UnionArg {
215	return &UnionArg{ArgCommon: ArgCommon{typ: t}, Option: opt}
216}
217
218func (arg *UnionArg) Size() uint64 {
219	if !arg.Type().Varlen() {
220		return arg.Type().Size()
221	}
222	return arg.Option.Size()
223}
224
225// Used for ResourceType.
226// This is the only argument that can be used as syscall return value.
227// Either holds constant value or reference another ResultArg.
228type ResultArg struct {
229	ArgCommon
230	Res   *ResultArg          // reference to arg which we use
231	OpDiv uint64              // divide result (executed before OpAdd)
232	OpAdd uint64              // add to result
233	Val   uint64              // value used if Res is nil
234	uses  map[*ResultArg]bool // ArgResult args that use this arg
235}
236
237func MakeResultArg(t Type, r *ResultArg, v uint64) *ResultArg {
238	arg := &ResultArg{ArgCommon: ArgCommon{typ: t}, Res: r, Val: v}
239	if r == nil {
240		return arg
241	}
242	if r.uses == nil {
243		r.uses = make(map[*ResultArg]bool)
244	}
245	r.uses[arg] = true
246	return arg
247}
248
249func MakeReturnArg(t Type) *ResultArg {
250	if t == nil {
251		return nil
252	}
253	if t.Dir() != DirOut {
254		panic("return arg is not out")
255	}
256	return &ResultArg{ArgCommon: ArgCommon{typ: t}}
257}
258
259func (arg *ResultArg) Size() uint64 {
260	return arg.typ.Size()
261}
262
263// Returns inner arg for pointer args.
264func InnerArg(arg Arg) Arg {
265	if t, ok := arg.Type().(*PtrType); ok {
266		if a, ok := arg.(*PointerArg); ok {
267			if a.Res == nil {
268				if !t.Optional() {
269					panic(fmt.Sprintf("non-optional pointer is nil\narg: %+v\ntype: %+v", a, t))
270				}
271				return nil
272			}
273			return InnerArg(a.Res)
274		}
275		return nil // *ConstArg.
276	}
277	return arg // Not a pointer.
278}
279
280func isDefault(arg Arg) bool {
281	return arg.Type().isDefaultArg(arg)
282}
283
284func (p *Prog) insertBefore(c *Call, calls []*Call) {
285	idx := 0
286	for ; idx < len(p.Calls); idx++ {
287		if p.Calls[idx] == c {
288			break
289		}
290	}
291	var newCalls []*Call
292	newCalls = append(newCalls, p.Calls[:idx]...)
293	newCalls = append(newCalls, calls...)
294	if idx < len(p.Calls) {
295		newCalls = append(newCalls, p.Calls[idx])
296		newCalls = append(newCalls, p.Calls[idx+1:]...)
297	}
298	p.Calls = newCalls
299}
300
301// replaceArg replaces arg with arg1 in a program.
302func replaceArg(arg, arg1 Arg) {
303	switch a := arg.(type) {
304	case *ConstArg:
305		*a = *arg1.(*ConstArg)
306	case *ResultArg:
307		replaceResultArg(a, arg1.(*ResultArg))
308	case *PointerArg:
309		*a = *arg1.(*PointerArg)
310	case *UnionArg:
311		*a = *arg1.(*UnionArg)
312	case *DataArg:
313		*a = *arg1.(*DataArg)
314	case *GroupArg:
315		a1 := arg1.(*GroupArg)
316		if len(a.Inner) != len(a1.Inner) {
317			panic(fmt.Sprintf("replaceArg: group fields don't match: %v/%v",
318				len(a.Inner), len(a1.Inner)))
319		}
320		a.ArgCommon = a1.ArgCommon
321		for i := range a.Inner {
322			replaceArg(a.Inner[i], a1.Inner[i])
323		}
324	default:
325		panic(fmt.Sprintf("replaceArg: bad arg kind %#v", arg))
326	}
327}
328
329func replaceResultArg(arg, arg1 *ResultArg) {
330	// Remove link from `a.Res` to `arg`.
331	if arg.Res != nil {
332		delete(arg.Res.uses, arg)
333	}
334	// Copy all fields from `arg1` to `arg` except for the list of args that use `arg`.
335	uses := arg.uses
336	*arg = *arg1
337	arg.uses = uses
338	// Make the link in `arg.Res` (which is now `Res` of `arg1`) to point to `arg` instead of `arg1`.
339	if arg.Res != nil {
340		resUses := arg.Res.uses
341		delete(resUses, arg1)
342		resUses[arg] = true
343	}
344}
345
346// removeArg removes all references to/from arg0 from a program.
347func removeArg(arg0 Arg) {
348	ForeachSubArg(arg0, func(arg Arg, ctx *ArgCtx) {
349		a, ok := arg.(*ResultArg)
350		if !ok {
351			return
352		}
353		if a.Res != nil {
354			uses := a.Res.uses
355			if !uses[a] {
356				panic("broken tree")
357			}
358			delete(uses, a)
359		}
360		for arg1 := range a.uses {
361			arg2 := arg1.Type().makeDefaultArg().(*ResultArg)
362			replaceResultArg(arg1, arg2)
363		}
364	})
365}
366
367// removeCall removes call idx from p.
368func (p *Prog) removeCall(idx int) {
369	c := p.Calls[idx]
370	for _, arg := range c.Args {
371		removeArg(arg)
372	}
373	if c.Ret != nil {
374		removeArg(c.Ret)
375	}
376	copy(p.Calls[idx:], p.Calls[idx+1:])
377	p.Calls = p.Calls[:len(p.Calls)-1]
378}
379