1// Copyright 2017 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4package prog
5
6import (
7	"fmt"
8)
9
// CsumChunkKind tells which kind of data a CsumChunk contributes to a checksum.
type CsumChunkKind int

const (
	// CsumChunkArg marks a chunk whose data is taken from an argument
	// (see CsumChunk.Arg).
	CsumChunkArg CsumChunkKind = iota
	// CsumChunkConst marks a chunk whose data is a constant
	// (see CsumChunk.Value and CsumChunk.Size).
	CsumChunkConst
)
16
17type CsumInfo struct {
18	Kind   CsumKind
19	Chunks []CsumChunk
20}
21
22type CsumChunk struct {
23	Kind  CsumChunkKind
24	Arg   Arg    // for CsumChunkArg
25	Value uint64 // for CsumChunkConst
26	Size  uint64 // for CsumChunkConst
27}
28
29func calcChecksumsCall(c *Call) (map[Arg]CsumInfo, map[Arg]struct{}) {
30	var inetCsumFields, pseudoCsumFields []Arg
31
32	// Find all csum fields.
33	ForeachArg(c, func(arg Arg, _ *ArgCtx) {
34		if typ, ok := arg.Type().(*CsumType); ok {
35			switch typ.Kind {
36			case CsumInet:
37				inetCsumFields = append(inetCsumFields, arg)
38			case CsumPseudo:
39				pseudoCsumFields = append(pseudoCsumFields, arg)
40			default:
41				panic(fmt.Sprintf("unknown csum kind %v", typ.Kind))
42			}
43		}
44	})
45
46	if len(inetCsumFields) == 0 && len(pseudoCsumFields) == 0 {
47		return nil, nil
48	}
49
50	// Build map of each field to its parent struct.
51	parentsMap := make(map[Arg]Arg)
52	ForeachArg(c, func(arg Arg, _ *ArgCtx) {
53		if _, ok := arg.Type().(*StructType); ok {
54			for _, field := range arg.(*GroupArg).Inner {
55				parentsMap[InnerArg(field)] = arg
56			}
57		}
58	})
59
60	csumMap := make(map[Arg]CsumInfo)
61	csumUses := make(map[Arg]struct{})
62
63	// Calculate generic inet checksums.
64	for _, arg := range inetCsumFields {
65		typ, _ := arg.Type().(*CsumType)
66		csummedArg := findCsummedArg(arg, typ, parentsMap)
67		csumUses[csummedArg] = struct{}{}
68		chunk := CsumChunk{CsumChunkArg, csummedArg, 0, 0}
69		csumMap[arg] = CsumInfo{Kind: CsumInet, Chunks: []CsumChunk{chunk}}
70	}
71
72	// No need to continue if there are no pseudo csum fields.
73	if len(pseudoCsumFields) == 0 {
74		return csumMap, csumUses
75	}
76
77	// Extract ipv4 or ipv6 source and destination addresses.
78	var ipSrcAddr, ipDstAddr Arg
79	ForeachArg(c, func(arg Arg, _ *ArgCtx) {
80		groupArg, ok := arg.(*GroupArg)
81		if !ok {
82			return
83		}
84		// syz_csum_* structs are used in tests
85		switch groupArg.Type().Name() {
86		case "ipv4_header", "syz_csum_ipv4_header":
87			ipSrcAddr, ipDstAddr = extractHeaderParams(groupArg, 4)
88		case "ipv6_packet", "syz_csum_ipv6_header":
89			ipSrcAddr, ipDstAddr = extractHeaderParams(groupArg, 16)
90		}
91	})
92	if ipSrcAddr == nil || ipDstAddr == nil {
93		panic("no ipv4 nor ipv6 header found")
94	}
95
96	// Calculate pseudo checksums.
97	for _, arg := range pseudoCsumFields {
98		typ, _ := arg.Type().(*CsumType)
99		csummedArg := findCsummedArg(arg, typ, parentsMap)
100		protocol := uint8(typ.Protocol)
101		var info CsumInfo
102		if ipSrcAddr.Size() == 4 {
103			info = composePseudoCsumIPv4(csummedArg, ipSrcAddr, ipDstAddr, protocol)
104		} else {
105			info = composePseudoCsumIPv6(csummedArg, ipSrcAddr, ipDstAddr, protocol)
106		}
107		csumMap[arg] = info
108	}
109
110	return csumMap, csumUses
111}
112
113func findCsummedArg(arg Arg, typ *CsumType, parentsMap map[Arg]Arg) Arg {
114	if typ.Buf == "parent" {
115		if csummedArg, ok := parentsMap[arg]; ok {
116			return csummedArg
117		}
118		panic(fmt.Sprintf("parent for %v is not in parents map", typ.Name()))
119	} else {
120		for parent := parentsMap[arg]; parent != nil; parent = parentsMap[parent] {
121			if typ.Buf == parent.Type().Name() {
122				return parent
123			}
124		}
125	}
126	panic(fmt.Sprintf("csum field '%v' references non existent field '%v'", typ.FieldName(), typ.Buf))
127}
128
129func composePseudoCsumIPv4(tcpPacket, srcAddr, dstAddr Arg, protocol uint8) CsumInfo {
130	info := CsumInfo{Kind: CsumInet}
131	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, srcAddr, 0, 0})
132	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, dstAddr, 0, 0})
133	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap16(uint16(protocol))), 2})
134	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap16(uint16(tcpPacket.Size()))), 2})
135	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, tcpPacket, 0, 0})
136	return info
137}
138
139func composePseudoCsumIPv6(tcpPacket, srcAddr, dstAddr Arg, protocol uint8) CsumInfo {
140	info := CsumInfo{Kind: CsumInet}
141	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, srcAddr, 0, 0})
142	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, dstAddr, 0, 0})
143	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap32(uint32(tcpPacket.Size()))), 4})
144	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap32(uint32(protocol))), 4})
145	info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, tcpPacket, 0, 0})
146	return info
147}
148
149func extractHeaderParams(arg *GroupArg, size uint64) (Arg, Arg) {
150	srcAddr := getFieldByName(arg, "src_ip")
151	dstAddr := getFieldByName(arg, "dst_ip")
152	if srcAddr.Size() != size || dstAddr.Size() != size {
153		panic(fmt.Sprintf("src/dst_ip fields in %v must be %v bytes", arg.Type().Name(), size))
154	}
155	return srcAddr, dstAddr
156}
157
158func getFieldByName(arg *GroupArg, name string) Arg {
159	for _, field := range arg.Inner {
160		if field.Type().FieldName() == name {
161			return field
162		}
163	}
164	panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type().Name()))
165}
166