1package prog
2
3import (
4	"fmt"
5)
6
7type anyTypes struct {
8	union  *UnionType
9	array  *ArrayType
10	blob   *BufferType
11	ptrPtr *PtrType
12	ptr64  *PtrType
13	res16  *ResourceType
14	res32  *ResourceType
15	res64  *ResourceType
16	resdec *ResourceType
17	reshex *ResourceType
18	resoct *ResourceType
19}
20
21// This generates type descriptions for:
22//
23// resource ANYRES16[int16]: 0xffffffffffffffff, 0
24// resource ANYRES32[int32]: 0xffffffffffffffff, 0
25// resource ANYRES64[int64]: 0xffffffffffffffff, 0
26// ANY [
27// 	bin	array[int8]
28// 	ptr	ptr[in, array[ANY], opt]
29// 	ptr64	ptr64[in, array[ANY], opt]
30// 	res16	ANYRES16
31// 	res32	ANYRES32
32// 	res64	ANYRES64
33//	resdec	fmt[dec, ANYRES64]
34//	reshex	fmt[hex, ANYRES64]
35//	resoct	fmt[oct, ANYRES64]
36// ] [varlen]
37func initAnyTypes(target *Target) {
38	target.any.union = &UnionType{
39		FldName: "ANYUNION",
40	}
41	target.any.array = &ArrayType{
42		TypeCommon: TypeCommon{
43			TypeName: "ANYARRAY",
44			FldName:  "ANYARRAY",
45			IsVarlen: true,
46		},
47		Type: target.any.union,
48	}
49	target.any.ptrPtr = &PtrType{
50		TypeCommon: TypeCommon{
51			TypeName:   "ptr",
52			FldName:    "ANYPTR",
53			TypeSize:   target.PtrSize,
54			IsOptional: true,
55		},
56		Type: target.any.array,
57	}
58	target.any.ptr64 = &PtrType{
59		TypeCommon: TypeCommon{
60			TypeName:   "ptr64",
61			FldName:    "ANYPTR64",
62			TypeSize:   8,
63			IsOptional: true,
64		},
65		Type: target.any.array,
66	}
67	target.any.blob = &BufferType{
68		TypeCommon: TypeCommon{
69			TypeName: "ANYBLOB",
70			FldName:  "ANYBLOB",
71			IsVarlen: true,
72		},
73	}
74	createResource := func(name, base string, bf BinaryFormat, size uint64) *ResourceType {
75		return &ResourceType{
76			TypeCommon: TypeCommon{
77				TypeName:   name,
78				FldName:    name,
79				ArgDir:     DirIn,
80				TypeSize:   size,
81				IsOptional: true,
82			},
83			ArgFormat: bf,
84			Desc: &ResourceDesc{
85				Name:   name,
86				Kind:   []string{name},
87				Values: []uint64{^uint64(0), 0},
88				Type: &IntType{
89					IntTypeCommon: IntTypeCommon{
90						TypeCommon: TypeCommon{
91							TypeName: base,
92							TypeSize: size,
93						},
94					},
95				},
96			},
97		}
98	}
99	target.any.res16 = createResource("ANYRES16", "int16", FormatNative, 2)
100	target.any.res32 = createResource("ANYRES32", "int32", FormatNative, 4)
101	target.any.res64 = createResource("ANYRES64", "int64", FormatNative, 8)
102	target.any.resdec = createResource("ANYRESDEC", "int64", FormatStrDec, 20)
103	target.any.reshex = createResource("ANYRESHEX", "int64", FormatStrHex, 18)
104	target.any.resoct = createResource("ANYRESOCT", "int64", FormatStrOct, 23)
105	target.any.union.StructDesc = &StructDesc{
106		TypeCommon: TypeCommon{
107			TypeName: "ANYUNION",
108			FldName:  "ANYUNION",
109			IsVarlen: true,
110			ArgDir:   DirIn,
111		},
112		Fields: []Type{
113			target.any.blob,
114			target.any.ptrPtr,
115			target.any.ptr64,
116			target.any.res16,
117			target.any.res32,
118			target.any.res64,
119			target.any.resdec,
120			target.any.reshex,
121			target.any.resoct,
122		},
123	}
124}
125
126func (target *Target) makeAnyPtrType(size uint64, field string) *PtrType {
127	// We need to make a copy because type holds field name,
128	// and field names are used as len target.
129	var typ PtrType
130	if size == target.PtrSize {
131		typ = *target.any.ptrPtr
132	} else if size == 8 {
133		typ = *target.any.ptr64
134	} else {
135		panic(fmt.Sprintf("bad pointer size %v", size))
136	}
137	typ.TypeSize = size
138	if field != "" {
139		typ.FldName = field
140	}
141	return &typ
142}
143
144func (target *Target) isAnyPtr(typ Type) bool {
145	ptr, ok := typ.(*PtrType)
146	return ok && ptr.Type == target.any.array
147}
148
149func (p *Prog) complexPtrs() (res []*PointerArg) {
150	for _, c := range p.Calls {
151		ForeachArg(c, func(arg Arg, ctx *ArgCtx) {
152			if ptrArg, ok := arg.(*PointerArg); ok && p.Target.isComplexPtr(ptrArg) {
153				res = append(res, ptrArg)
154				ctx.Stop = true
155			}
156		})
157	}
158	return
159}
160
161func (target *Target) isComplexPtr(arg *PointerArg) bool {
162	if arg.Res == nil || arg.Type().Dir() != DirIn {
163		return false
164	}
165	if target.isAnyPtr(arg.Type()) {
166		return true
167	}
168	res := false
169	ForeachSubArg(arg.Res, func(a1 Arg, ctx *ArgCtx) {
170		switch typ := a1.Type().(type) {
171		case *StructType:
172			if typ.Varlen() {
173				res = true
174				ctx.Stop = true
175			}
176		case *UnionType:
177			if typ.Varlen() && len(typ.Fields) > 5 {
178				res = true
179				ctx.Stop = true
180			}
181		case *PtrType:
182			if a1 != arg {
183				ctx.Stop = true
184			}
185		}
186	})
187	return res
188}
189
190func (target *Target) CallContainsAny(c *Call) (res bool) {
191	ForeachArg(c, func(arg Arg, ctx *ArgCtx) {
192		if target.isAnyPtr(arg.Type()) {
193			res = true
194			ctx.Stop = true
195		}
196	})
197	return
198}
199
200func (target *Target) ArgContainsAny(arg0 Arg) (res bool) {
201	ForeachSubArg(arg0, func(arg Arg, ctx *ArgCtx) {
202		if target.isAnyPtr(arg.Type()) {
203			res = true
204			ctx.Stop = true
205		}
206	})
207	return
208}
209
210func (target *Target) squashPtr(arg *PointerArg, preserveField bool) {
211	if arg.Res == nil || arg.VmaSize != 0 {
212		panic("bad ptr arg")
213	}
214	res0 := arg.Res
215	size0 := res0.Size()
216	var elems []Arg
217	target.squashPtrImpl(arg.Res, &elems)
218	field := ""
219	if preserveField {
220		field = arg.Type().FieldName()
221	}
222	arg.typ = target.makeAnyPtrType(arg.Type().Size(), field)
223	arg.Res = MakeGroupArg(arg.typ.(*PtrType).Type, elems)
224	if size := arg.Res.Size(); size != size0 {
225		panic(fmt.Sprintf("squash changed size %v->%v for %v", size0, size, res0.Type()))
226	}
227}
228
229func (target *Target) squashPtrImpl(a Arg, elems *[]Arg) {
230	if a.Type().BitfieldLength() != 0 {
231		panic("bitfield in squash")
232	}
233	var pad uint64
234	switch arg := a.(type) {
235	case *ConstArg:
236		target.squashConst(arg, elems)
237	case *ResultArg:
238		target.squashResult(arg, elems)
239	case *PointerArg:
240		if arg.Res != nil {
241			target.squashPtr(arg, false)
242			*elems = append(*elems, MakeUnionArg(target.any.union, arg))
243		} else {
244			elem := target.ensureDataElem(elems)
245			addr := target.PhysicalAddr(arg)
246			for i := uint64(0); i < arg.Size(); i++ {
247				elem.data = append(elem.Data(), byte(addr))
248				addr >>= 8
249			}
250		}
251	case *UnionArg:
252		if !arg.Type().Varlen() {
253			pad = arg.Size() - arg.Option.Size()
254		}
255		target.squashPtrImpl(arg.Option, elems)
256	case *DataArg:
257		if arg.Type().Dir() == DirOut {
258			pad = arg.Size()
259		} else {
260			elem := target.ensureDataElem(elems)
261			elem.data = append(elem.Data(), arg.Data()...)
262		}
263	case *GroupArg:
264		target.squashGroup(arg, elems)
265	default:
266		panic("bad arg kind")
267	}
268	if pad != 0 {
269		elem := target.ensureDataElem(elems)
270		elem.data = append(elem.Data(), make([]byte, pad)...)
271	}
272}
273
274func (target *Target) squashConst(arg *ConstArg, elems *[]Arg) {
275	if IsPad(arg.Type()) {
276		elem := target.ensureDataElem(elems)
277		elem.data = append(elem.Data(), make([]byte, arg.Size())...)
278		return
279	}
280	v, bf := target.squashedValue(arg)
281	var data []byte
282	switch bf {
283	case FormatNative:
284		for i := uint64(0); i < arg.Size(); i++ {
285			data = append(data, byte(v))
286			v >>= 8
287		}
288	case FormatStrDec:
289		data = []byte(fmt.Sprintf("%020v", v))
290	case FormatStrHex:
291		data = []byte(fmt.Sprintf("0x%016x", v))
292	case FormatStrOct:
293		data = []byte(fmt.Sprintf("%023o", v))
294	default:
295		panic(fmt.Sprintf("unknown binary format: %v", bf))
296	}
297	if uint64(len(data)) != arg.Size() {
298		panic("squashed value of wrong size")
299	}
300	elem := target.ensureDataElem(elems)
301	elem.data = append(elem.Data(), data...)
302}
303
304func (target *Target) squashResult(arg *ResultArg, elems *[]Arg) {
305	switch arg.Type().Format() {
306	case FormatNative, FormatBigEndian:
307		switch arg.Size() {
308		case 2:
309			arg.typ = target.any.res16
310		case 4:
311			arg.typ = target.any.res32
312		case 8:
313			arg.typ = target.any.res64
314		default:
315			panic("bad size")
316		}
317	case FormatStrDec:
318		arg.typ = target.any.resdec
319	case FormatStrHex:
320		arg.typ = target.any.reshex
321	case FormatStrOct:
322		arg.typ = target.any.resoct
323	default:
324		panic("bad")
325	}
326	*elems = append(*elems, MakeUnionArg(target.any.union, arg))
327}
328
329func (target *Target) squashGroup(arg *GroupArg, elems *[]Arg) {
330	var pad uint64
331	if typ, ok := arg.Type().(*StructType); ok && typ.Varlen() && typ.AlignAttr != 0 {
332		var fieldsSize uint64
333		for _, fld := range arg.Inner {
334			if !fld.Type().BitfieldMiddle() {
335				fieldsSize += fld.Size()
336			}
337		}
338		if fieldsSize%typ.AlignAttr != 0 {
339			pad = typ.AlignAttr - fieldsSize%typ.AlignAttr
340		}
341	}
342	var bitfield uint64
343	for _, fld := range arg.Inner {
344		// Squash bitfields separately.
345		if bfLen := fld.Type().BitfieldLength(); bfLen != 0 {
346			bfOff := fld.Type().BitfieldOffset()
347			// Note: we can have a ResultArg here as well,
348			// but it is unsupported at the moment.
349			v, bf := target.squashedValue(fld.(*ConstArg))
350			if bf != FormatNative {
351				panic(fmt.Sprintf("bitfield has bad format %v", bf))
352			}
353			bitfield |= (v & ((1 << bfLen) - 1)) << bfOff
354			if !fld.Type().BitfieldMiddle() {
355				elem := target.ensureDataElem(elems)
356				for i := uint64(0); i < fld.Size(); i++ {
357					elem.data = append(elem.Data(), byte(bitfield))
358					bitfield >>= 8
359				}
360				bitfield = 0
361			}
362			continue
363		}
364		target.squashPtrImpl(fld, elems)
365	}
366	if pad != 0 {
367		elem := target.ensureDataElem(elems)
368		elem.data = append(elem.Data(), make([]byte, pad)...)
369	}
370}
371
372func (target *Target) squashedValue(arg *ConstArg) (uint64, BinaryFormat) {
373	bf := arg.Type().Format()
374	if _, ok := arg.Type().(*CsumType); ok {
375		// We can't compute value for the checksum here,
376		// but at least leave something recognizable by hints code.
377		// TODO: hints code won't recognize this, because it won't find
378		// the const in any arg. We either need to put this const as
379		// actual csum arg value, or special case it in hints.
380		return 0xabcdef1234567890, FormatNative
381	}
382	// Note: we need a constant value, but it depends on pid for proc.
383	v, _ := arg.Value()
384	if bf == FormatBigEndian {
385		bf = FormatNative
386		switch arg.Size() {
387		case 2:
388			v = uint64(swap16(uint16(v)))
389		case 4:
390			v = uint64(swap32(uint32(v)))
391		case 8:
392			v = swap64(v)
393		default:
394			panic(fmt.Sprintf("bad const size %v", arg.Size()))
395		}
396	}
397	return v, bf
398}
399
400func (target *Target) ensureDataElem(elems *[]Arg) *DataArg {
401	if len(*elems) == 0 {
402		res := MakeDataArg(target.any.blob, nil)
403		*elems = append(*elems, MakeUnionArg(target.any.union, res))
404		return res
405	}
406	res, ok := (*elems)[len(*elems)-1].(*UnionArg).Option.(*DataArg)
407	if !ok {
408		res = MakeDataArg(target.any.blob, nil)
409		*elems = append(*elems, MakeUnionArg(target.any.union, res))
410	}
411	return res
412}
413