1// Copyright 2018 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4package prog
5
6import (
7	"fmt"
8)
9
10// Minimize minimizes program p into an equivalent program using the equivalence
11// predicate pred.  It iteratively generates simpler programs and asks pred
12// whether it is equal to the original program or not. If it is equivalent then
13// the simplification attempt is committed and the process continues.
14func Minimize(p0 *Prog, callIndex0 int, crash bool, pred0 func(*Prog, int) bool) (*Prog, int) {
15	pred := func(p *Prog, callIndex int) bool {
16		for _, call := range p.Calls {
17			p.Target.SanitizeCall(call)
18		}
19		p.debugValidate()
20		return pred0(p, callIndex)
21	}
22	name0 := ""
23	if callIndex0 != -1 {
24		if callIndex0 < 0 || callIndex0 >= len(p0.Calls) {
25			panic("bad call index")
26		}
27		name0 = p0.Calls[callIndex0].Meta.Name
28	}
29
30	// Try to remove all calls except the last one one-by-one.
31	p0, callIndex0 = removeCalls(p0, callIndex0, crash, pred)
32
33	// Try to minimize individual args.
34	for i := 0; i < len(p0.Calls); i++ {
35		ctx := &minimizeArgsCtx{
36			target:     p0.Target,
37			p0:         &p0,
38			callIndex0: callIndex0,
39			crash:      crash,
40			pred:       pred,
41			triedPaths: make(map[string]bool),
42		}
43	again:
44		ctx.p = p0.Clone()
45		ctx.call = ctx.p.Calls[i]
46		for j, arg := range ctx.call.Args {
47			if ctx.do(arg, fmt.Sprintf("%v", j)) {
48				goto again
49			}
50		}
51	}
52
53	if callIndex0 != -1 {
54		if callIndex0 < 0 || callIndex0 >= len(p0.Calls) || name0 != p0.Calls[callIndex0].Meta.Name {
55			panic(fmt.Sprintf("bad call index after minimization: ncalls=%v index=%v call=%v/%v",
56				len(p0.Calls), callIndex0, name0, p0.Calls[callIndex0].Meta.Name))
57		}
58	}
59	return p0, callIndex0
60}
61
62func removeCalls(p0 *Prog, callIndex0 int, crash bool, pred func(*Prog, int) bool) (*Prog, int) {
63	for i := len(p0.Calls) - 1; i >= 0; i-- {
64		if i == callIndex0 {
65			continue
66		}
67		callIndex := callIndex0
68		if i < callIndex {
69			callIndex--
70		}
71		p := p0.Clone()
72		p.removeCall(i)
73		if !pred(p, callIndex) {
74			continue
75		}
76		p0 = p
77		callIndex0 = callIndex
78	}
79	return p0, callIndex0
80}
81
// minimizeArgsCtx carries the state shared by the per-type minimize methods
// while the arguments of a single call are being minimized.
type minimizeArgsCtx struct {
	// target is the target of the program being minimized.
	target *Target
	// p0 points at the caller's currently committed program; minimizers store
	// ctx.p here once pred accepts a change.
	p0 **Prog
	// p is the working clone that minimizers mutate in place.
	p *Prog
	// call is the call within p whose arguments are being minimized.
	call *Call
	// callIndex0 is the index of the call of interest, passed through to pred.
	callIndex0 int
	// crash disables or limits some simplifications when set.
	crash bool
	// pred reports whether a candidate program is still equivalent.
	pred func(*Prog, int) bool
	// triedPaths records argument paths that were fully explored, and array
	// element paths whose removal was already attempted.
	triedPaths map[string]bool
}
92
93func (ctx *minimizeArgsCtx) do(arg Arg, path string) bool {
94	path += fmt.Sprintf("-%v", arg.Type().FieldName())
95	if ctx.triedPaths[path] {
96		return false
97	}
98	if arg.Type().minimize(ctx, arg, path) {
99		return true
100	}
101	ctx.triedPaths[path] = true
102	return false
103}
104
105func (typ *TypeCommon) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool {
106	return false
107}
108
109func (typ *StructType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool {
110	a := arg.(*GroupArg)
111	for _, innerArg := range a.Inner {
112		if ctx.do(innerArg, path) {
113			return true
114		}
115	}
116	return false
117}
118
119func (typ *UnionType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool {
120	return ctx.do(arg.(*UnionArg).Option, path)
121}
122
123func (typ *PtrType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool {
124	// TODO: try to remove optional ptrs
125	a := arg.(*PointerArg)
126	if a.Res == nil {
127		return false
128	}
129	return ctx.do(a.Res, path)
130}
131
132func (typ *ArrayType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool {
133	a := arg.(*GroupArg)
134	for i := len(a.Inner) - 1; i >= 0; i-- {
135		elem := a.Inner[i]
136		elemPath := fmt.Sprintf("%v-%v", path, i)
137		// Try to remove individual elements one-by-one.
138		if !ctx.crash && !ctx.triedPaths[elemPath] &&
139			(typ.Kind == ArrayRandLen ||
140				typ.Kind == ArrayRangeLen && uint64(len(a.Inner)) > typ.RangeBegin) {
141			ctx.triedPaths[elemPath] = true
142			copy(a.Inner[i:], a.Inner[i+1:])
143			a.Inner = a.Inner[:len(a.Inner)-1]
144			removeArg(elem)
145			ctx.target.assignSizesCall(ctx.call)
146			if ctx.pred(ctx.p, ctx.callIndex0) {
147				*ctx.p0 = ctx.p
148			}
149			return true
150		}
151		if ctx.do(elem, elemPath) {
152			return true
153		}
154	}
155	return false
156}
157
158func (typ *IntType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool {
159	return minimizeInt(ctx, arg, path)
160}
161
162func (typ *FlagsType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool {
163	return minimizeInt(ctx, arg, path)
164}
165
166func (typ *ProcType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool {
167	return minimizeInt(ctx, arg, path)
168}
169
170func minimizeInt(ctx *minimizeArgsCtx, arg Arg, path string) bool {
171	// TODO: try to reset bits in ints
172	// TODO: try to set separate flags
173	if ctx.crash {
174		return false
175	}
176	a := arg.(*ConstArg)
177	def := arg.Type().makeDefaultArg().(*ConstArg)
178	if a.Val == def.Val {
179		return false
180	}
181	v0 := a.Val
182	a.Val = def.Val
183	if ctx.pred(ctx.p, ctx.callIndex0) {
184		*ctx.p0 = ctx.p
185	} else {
186		a.Val = v0
187	}
188	return false
189}
190
191func (typ *ResourceType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool {
192	if ctx.crash {
193		return false
194	}
195	a := arg.(*ResultArg)
196	if a.Res == nil {
197		return false
198	}
199	r0 := a.Res
200	delete(a.Res.uses, a)
201	a.Res, a.Val = nil, typ.Default()
202	if ctx.pred(ctx.p, ctx.callIndex0) {
203		*ctx.p0 = ctx.p
204	} else {
205		a.Res, a.Val = r0, 0
206		a.Res.uses[a] = true
207	}
208	return false
209}
210
211func (typ *BufferType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool {
212	// TODO: try to set individual bytes to 0
213	if typ.Kind != BufferBlobRand && typ.Kind != BufferBlobRange || typ.Dir() == DirOut {
214		return false
215	}
216	a := arg.(*DataArg)
217	minLen := int(typ.RangeBegin)
218	for step := len(a.Data()) - minLen; len(a.Data()) > minLen && step > 0; {
219		if len(a.Data())-step >= minLen {
220			a.data = a.Data()[:len(a.Data())-step]
221			ctx.target.assignSizesCall(ctx.call)
222			if ctx.pred(ctx.p, ctx.callIndex0) {
223				continue
224			}
225			a.data = a.Data()[:len(a.Data())+step]
226			ctx.target.assignSizesCall(ctx.call)
227		}
228		step /= 2
229		if ctx.crash {
230			break
231		}
232	}
233	*ctx.p0 = ctx.p
234	return false
235}
236