1// Copyright 2017 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4package prog
5
6import (
7	"fmt"
8)
9
10type ExecProg struct {
11	Calls []ExecCall
12	Vars  []uint64
13}
14
15type ExecCall struct {
16	Meta    *Syscall
17	Index   uint64
18	Args    []ExecArg
19	Copyin  []ExecCopyin
20	Copyout []ExecCopyout
21}
22
23type ExecCopyin struct {
24	Addr uint64
25	Arg  ExecArg
26}
27
28type ExecCopyout struct {
29	Index uint64
30	Addr  uint64
31	Size  uint64
32}
33
34type ExecArg interface{} // one of ExecArg*
35
36type ExecArgConst struct {
37	Size           uint64
38	Format         BinaryFormat
39	Value          uint64
40	BitfieldOffset uint64
41	BitfieldLength uint64
42	PidStride      uint64
43}
44
45type ExecArgResult struct {
46	Size    uint64
47	Format  BinaryFormat
48	Index   uint64
49	DivOp   uint64
50	AddOp   uint64
51	Default uint64
52}
53
54type ExecArgData struct {
55	Data []byte
56}
57
58type ExecArgCsum struct {
59	Size   uint64
60	Kind   uint64
61	Chunks []ExecCsumChunk
62}
63
64type ExecCsumChunk struct {
65	Kind  uint64
66	Value uint64
67	Size  uint64
68}
69
70func (target *Target) DeserializeExec(exec []byte) (ExecProg, error) {
71	dec := &execDecoder{target: target, data: exec}
72	dec.parse()
73	if dec.err != nil {
74		return ExecProg{}, dec.err
75	}
76	if uint64(len(dec.vars)) != dec.numVars {
77		return ExecProg{}, fmt.Errorf("mismatching number of vars: %v/%v",
78			len(dec.vars), dec.numVars)
79	}
80	p := ExecProg{
81		Calls: dec.calls,
82		Vars:  dec.vars,
83	}
84	return p, nil
85}
86
// execDecoder holds the state of a single DeserializeExec run.
type execDecoder struct {
	target  *Target    // provides the Syscalls table indexed by call instructions
	data    []byte     // input not yet consumed; read and readBlob advance it
	err     error      // first error recorded via setErr; later ones are dropped
	numVars uint64     // result slots produced by committed calls and copyouts
	vars    []uint64   // per result index, Default of the last result arg seen
	call    ExecCall   // call currently being assembled (Meta nil until seen)
	calls   []ExecCall // calls committed so far, in stream order
}
96
97func (dec *execDecoder) parse() {
98	for dec.err == nil {
99		switch instr := dec.read(); instr {
100		case execInstrCopyin:
101			dec.commitCall()
102			dec.call.Copyin = append(dec.call.Copyin, ExecCopyin{
103				Addr: dec.read(),
104				Arg:  dec.readArg(),
105			})
106		case execInstrCopyout:
107			dec.call.Copyout = append(dec.call.Copyout, ExecCopyout{
108				Index: dec.read(),
109				Addr:  dec.read(),
110				Size:  dec.read(),
111			})
112		default:
113			dec.commitCall()
114			if instr >= uint64(len(dec.target.Syscalls)) {
115				dec.setErr(fmt.Errorf("bad syscall %v", instr))
116				return
117			}
118			dec.call.Meta = dec.target.Syscalls[instr]
119			dec.call.Index = dec.read()
120			for i := dec.read(); i > 0; i-- {
121				switch arg := dec.readArg(); arg.(type) {
122				case ExecArgConst, ExecArgResult:
123					dec.call.Args = append(dec.call.Args, arg)
124				default:
125					dec.setErr(fmt.Errorf("bad call arg %+v", arg))
126					return
127				}
128			}
129		case execInstrEOF:
130			dec.commitCall()
131			return
132		}
133	}
134}
135
136func (dec *execDecoder) readArg() ExecArg {
137	switch typ := dec.read(); typ {
138	case execArgConst:
139		meta := dec.read()
140		return ExecArgConst{
141			Value:          dec.read(),
142			Size:           meta & 0xff,
143			Format:         BinaryFormat((meta >> 8) & 0xff),
144			BitfieldOffset: (meta >> 16) & 0xff,
145			BitfieldLength: (meta >> 24) & 0xff,
146			PidStride:      meta >> 32,
147		}
148	case execArgResult:
149		meta := dec.read()
150		arg := ExecArgResult{
151			Size:    meta & 0xff,
152			Format:  BinaryFormat((meta >> 8) & 0xff),
153			Index:   dec.read(),
154			DivOp:   dec.read(),
155			AddOp:   dec.read(),
156			Default: dec.read(),
157		}
158		for uint64(len(dec.vars)) <= arg.Index {
159			dec.vars = append(dec.vars, 0)
160		}
161		dec.vars[arg.Index] = arg.Default
162		return arg
163	case execArgData:
164		return ExecArgData{
165			Data: dec.readBlob(dec.read()),
166		}
167	case execArgCsum:
168		size := dec.read()
169		switch kind := dec.read(); kind {
170		case ExecArgCsumInet:
171			chunks := make([]ExecCsumChunk, dec.read())
172			for i := range chunks {
173				chunks[i] = ExecCsumChunk{
174					Kind:  dec.read(),
175					Value: dec.read(),
176					Size:  dec.read(),
177				}
178			}
179			return ExecArgCsum{
180				Size:   size,
181				Kind:   kind,
182				Chunks: chunks,
183			}
184		default:
185			dec.setErr(fmt.Errorf("unknown csum kind %v", kind))
186			return nil
187		}
188	default:
189		dec.setErr(fmt.Errorf("bad argument type %v", typ))
190		return nil
191	}
192}
193
194func (dec *execDecoder) read() uint64 {
195	if len(dec.data) < 8 {
196		dec.setErr(fmt.Errorf("exec program overflow"))
197	}
198	if dec.err != nil {
199		return 0
200	}
201	var v uint64
202	for i := 0; i < 8; i++ {
203		v |= uint64(dec.data[i]) << uint(i*8)
204	}
205	dec.data = dec.data[8:]
206	return v
207}
208
209func (dec *execDecoder) readBlob(size uint64) []byte {
210	padded := (size + 7) / 8 * 8
211	if uint64(len(dec.data)) < padded {
212		dec.setErr(fmt.Errorf("exec program overflow"))
213	}
214	if dec.err != nil {
215		return nil
216	}
217	data := dec.data[:size]
218	dec.data = dec.data[padded:]
219	return data
220}
221
222func (dec *execDecoder) setErr(err error) {
223	if dec.err == nil {
224		dec.err = err
225	}
226}
227
228func (dec *execDecoder) commitCall() {
229	if dec.call.Meta == nil {
230		return
231	}
232	if dec.call.Index != ExecNoCopyout && dec.numVars < dec.call.Index+1 {
233		dec.numVars = dec.call.Index + 1
234	}
235	for _, copyout := range dec.call.Copyout {
236		if dec.numVars < copyout.Index+1 {
237			dec.numVars = copyout.Index + 1
238		}
239	}
240	dec.calls = append(dec.calls, dec.call)
241	dec.call = ExecCall{}
242}
243