1// Copyright 2017 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4package prog
5
6import (
7	"fmt"
8	"strings"
9)
10
11func (target *Target) generateSize(arg Arg, lenType *LenType) uint64 {
12	if arg == nil {
13		// Arg is an optional pointer, set size to 0.
14		return 0
15	}
16
17	bitSize := lenType.BitSize
18	if bitSize == 0 {
19		bitSize = 8
20	}
21	switch arg.Type().(type) {
22	case *VmaType:
23		a := arg.(*PointerArg)
24		return a.VmaSize * 8 / bitSize
25	case *ArrayType:
26		a := arg.(*GroupArg)
27		if lenType.BitSize != 0 {
28			return a.Size() * 8 / bitSize
29		}
30		return uint64(len(a.Inner))
31	default:
32		return arg.Size() * 8 / bitSize
33	}
34}
35
36func (target *Target) assignSizes(args []Arg, parentsMap map[Arg]Arg) {
37	// Create a map from field names to args.
38	argsMap := make(map[string]Arg)
39	for _, arg := range args {
40		if IsPad(arg.Type()) {
41			continue
42		}
43		argsMap[arg.Type().FieldName()] = arg
44	}
45
46	// Fill in size arguments.
47	for _, arg := range args {
48		if arg = InnerArg(arg); arg == nil {
49			continue // Pointer to optional len field, no need to fill in value.
50		}
51		if typ, ok := arg.Type().(*LenType); ok {
52			a := arg.(*ConstArg)
53
54			buf, ok := argsMap[typ.Buf]
55			if ok {
56				a.Val = target.generateSize(InnerArg(buf), typ)
57				continue
58			}
59
60			if typ.Buf == "parent" {
61				a.Val = parentsMap[arg].Size()
62				if typ.BitSize != 0 {
63					a.Val = a.Val * 8 / typ.BitSize
64				}
65				continue
66			}
67
68			sizeAssigned := false
69			for parent := parentsMap[arg]; parent != nil; parent = parentsMap[parent] {
70				parentName := parent.Type().Name()
71				if pos := strings.IndexByte(parentName, '['); pos != -1 {
72					// For template parents, strip arguments.
73					parentName = parentName[:pos]
74				}
75				if typ.Buf == parentName {
76					a.Val = parent.Size()
77					if typ.BitSize != 0 {
78						a.Val = a.Val * 8 / typ.BitSize
79					}
80					sizeAssigned = true
81					break
82				}
83			}
84			if sizeAssigned {
85				continue
86			}
87
88			panic(fmt.Sprintf("len field '%v' references non existent field '%v', argsMap: %+v",
89				typ.FieldName(), typ.Buf, argsMap))
90		}
91	}
92}
93
94func (target *Target) assignSizesArray(args []Arg) {
95	parentsMap := make(map[Arg]Arg)
96	for _, arg := range args {
97		ForeachSubArg(arg, func(arg Arg, _ *ArgCtx) {
98			if _, ok := arg.Type().(*StructType); ok {
99				for _, field := range arg.(*GroupArg).Inner {
100					parentsMap[InnerArg(field)] = arg
101				}
102			}
103		})
104	}
105	target.assignSizes(args, parentsMap)
106	for _, arg := range args {
107		ForeachSubArg(arg, func(arg Arg, _ *ArgCtx) {
108			if _, ok := arg.Type().(*StructType); ok {
109				target.assignSizes(arg.(*GroupArg).Inner, parentsMap)
110			}
111		})
112	}
113}
114
115func (target *Target) assignSizesCall(c *Call) {
116	target.assignSizesArray(c.Args)
117}
118
119func (r *randGen) mutateSize(arg *ConstArg, parent []Arg) bool {
120	typ := arg.Type().(*LenType)
121	elemSize := typ.BitSize / 8
122	if elemSize == 0 {
123		elemSize = 1
124		for _, field := range parent {
125			if typ.Buf != field.Type().FieldName() {
126				continue
127			}
128			if inner := InnerArg(field); inner != nil {
129				switch targetType := inner.Type().(type) {
130				case *VmaType:
131					return false
132				case *ArrayType:
133					if targetType.Type.Varlen() {
134						return false
135					}
136					elemSize = targetType.Type.Size()
137				}
138			}
139			break
140		}
141	}
142	if r.oneOf(100) {
143		arg.Val = r.rand64()
144		return true
145	}
146	if r.bin() {
147		// Small adjustment to trigger missed size checks.
148		if arg.Val != 0 && r.bin() {
149			arg.Val = r.randRangeInt(0, arg.Val-1)
150		} else {
151			arg.Val = r.randRangeInt(arg.Val+1, arg.Val+1000)
152		}
153		return true
154	}
155	// Try to provoke int overflows.
156	max := ^uint64(0)
157	if r.oneOf(3) {
158		max = 1<<32 - 1
159		if r.oneOf(2) {
160			max = 1<<16 - 1
161			if r.oneOf(2) {
162				max = 1<<8 - 1
163			}
164		}
165	}
166	n := max / elemSize
167	delta := uint64(1000 - r.biasedRand(1000, 10))
168	if elemSize == 1 || r.oneOf(10) {
169		n -= delta
170	} else {
171		n += delta
172	}
173	arg.Val = n
174	return true
175}
176