1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 @file:JvmName("PacketReflectorUtil")
18 
19 package com.android.testutils
20 
21 import android.system.ErrnoException
22 import android.system.Os
23 import android.system.OsConstants
24 import com.android.net.module.util.IpUtils
25 import com.android.testutils.PacketReflector.IPV4_HEADER_LENGTH
26 import com.android.testutils.PacketReflector.IPV6_HEADER_LENGTH
27 import java.io.FileDescriptor
28 import java.io.InterruptedIOException
29 import java.net.InetAddress
30 import java.nio.ByteBuffer
31 
readPacketnull32 fun readPacket(fd: FileDescriptor, buf: ByteArray): Int {
33     return try {
34         Os.read(fd, buf, 0, buf.size)
35     } catch (e: ErrnoException) {
36         // Ignore normal use cases such as the EAGAIN error indicates that the read operation
37         // cannot be completed immediately, or the EINTR error indicates that the read
38         // operation was interrupted by a signal.
39         if (e.errno == OsConstants.EAGAIN || e.errno == OsConstants.EINTR) {
40             -1
41         } else {
42             throw e
43         }
44     } catch (e: InterruptedIOException) {
45         -1
46     }
47 }
48 
getInetAddressAtnull49 fun getInetAddressAt(buf: ByteArray, pos: Int, len: Int): InetAddress =
50     InetAddress.getByAddress(buf.copyOfRange(pos, pos + len))
51 
52 /**
53  * Reads a 16-bit unsigned int at pos in big endian, with no alignment requirements.
54  */
55 fun getPortAt(buf: ByteArray, pos: Int): Int {
56     return (buf[pos].toInt() and 0xff shl 8) + (buf[pos + 1].toInt() and 0xff)
57 }
58 
setPortAtnull59 fun setPortAt(port: Int, buf: ByteArray, pos: Int) {
60     buf[pos] = (port ushr 8).toByte()
61     buf[pos + 1] = (port and 0xff).toByte()
62 }
63 
getAddressPositionAndLengthnull64 fun getAddressPositionAndLength(version: Int) = when (version) {
65     4 -> PacketReflector.IPV4_ADDR_OFFSET to PacketReflector.IPV4_ADDR_LENGTH
66     6 -> PacketReflector.IPV6_ADDR_OFFSET to PacketReflector.IPV6_ADDR_LENGTH
67     else -> throw IllegalArgumentException("Unknown IP version $version")
68 }
69 
/** Byte offset of the header checksum field within an IPv4 header. */
private const val IPV4_CHKSUM_OFFSET = 10
/** Byte offset of the checksum field within a UDP header. */
private const val UDP_CHECKSUM_OFFSET = 6
/** Byte offset of the checksum field within a TCP header. */
private const val TCP_CHECKSUM_OFFSET = 16

/**
 * Writes a 16-bit checksum into `buf[pos]` and `buf[pos + 1]` in big-endian (network) order.
 * Only the low 16 bits of [checksum] are used.
 */
private fun putChecksumAt(buf: ByteArray, pos: Int, checksum: Int) {
    buf[pos] = (checksum ushr 8).toByte()
    buf[pos + 1] = checksum.toByte()
}

/**
 * Recomputes, in place, the checksums of the IP packet stored at the start of [buf].
 *
 * When [version] is 4, the IPv4 header checksum is recomputed first. IPv6 headers have no
 * checksum field. The UDP or TCP checksum is then recomputed.
 *
 * The transport header is taken to begin immediately after a fixed-size IP header:
 * [IPV4_HEADER_LENGTH] when [version] is 4, otherwise [IPV6_HEADER_LENGTH]. Any [version]
 * other than 4 is therefore treated as IPv6 here. IPv4 options and IPv6 extension headers are
 * not accounted for.
 *
 * @param buf the packet bytes, beginning with the IP header.
 * @param len total packet length in [buf]. Only used to bound the TCP checksum; the UDP
 *   checksum length is taken by [IpUtils.udpChecksum] from the packet itself.
 * @param version IP version of the packet (4 or 6).
 * @param protocol transport protocol, [PacketReflector.IPPROTO_UDP] or
 *   [PacketReflector.IPPROTO_TCP].
 * @throws IllegalArgumentException if [protocol] is neither UDP nor TCP. Note that for IPv4
 *   the IP header checksum has already been rewritten by the time this is thrown.
 */
fun fixPacketChecksum(buf: ByteArray, len: Int, version: Int, protocol: Byte) {
    // Fill the IP checksum for IPv4. The IPv6 header doesn't have a checksum field.
    if (version == 4) {
        // RFC 791: the header checksum is computed with the checksum field itself set to zero.
        // IpUtils.ipChecksum sums the whole header, so clear any stale value first. Otherwise
        // the old checksum is folded into the new one; e.g. a header that already carries a
        // valid checksum would be rewritten with 0.
        putChecksumAt(buf, IPV4_CHKSUM_OFFSET, 0)
        val ipChecksum = IpUtils.ipChecksum(ByteBuffer.wrap(buf), 0)
        putChecksumAt(buf, IPV4_CHKSUM_OFFSET, ipChecksum.toInt())
    }

    // Fill the transport layer checksum.
    val transportOffset = if (version == 4) IPV4_HEADER_LENGTH else IPV6_HEADER_LENGTH
    when (protocol) {
        PacketReflector.IPPROTO_UDP -> {
            val checksumPos = transportOffset + UDP_CHECKSUM_OFFSET
            // The checksum field is part of the summed data, so it must be zero while computing.
            putChecksumAt(buf, checksumPos, 0)
            val udpChecksum = IpUtils.udpChecksum(ByteBuffer.wrap(buf), 0, transportOffset)
            putChecksumAt(buf, checksumPos, udpChecksum.toInt())
        }
        PacketReflector.IPPROTO_TCP -> {
            val checksumPos = transportOffset + TCP_CHECKSUM_OFFSET
            // The checksum field is part of the summed data, so it must be zero while computing.
            putChecksumAt(buf, checksumPos, 0)
            val transportLen = len - transportOffset
            val tcpChecksum = IpUtils.tcpChecksum(
                ByteBuffer.wrap(buf), 0, transportOffset, transportLen
            )
            putChecksumAt(buf, checksumPos, tcpChecksum.toInt())
        }
        // TODO: Support ICMP.
        else -> throw IllegalArgumentException("Unsupported protocol: $protocol")
    }
}
115 
swapBytesnull116 fun swapBytes(buf: ByteArray, pos1: Int, pos2: Int, len: Int) {
117     for (i in 0 until len) {
118         val b = buf[pos1 + i]
119         buf[pos1 + i] = buf[pos2 + i]
120         buf[pos2 + i] = b
121     }
122 }
123 
swapAddressesnull124 fun swapAddresses(buf: ByteArray, version: Int) {
125     val addrPos: Int
126     val addrLen: Int
127     when (version) {
128         4 -> {
129             addrPos = PacketReflector.IPV4_ADDR_OFFSET
130             addrLen = PacketReflector.IPV4_ADDR_LENGTH
131         }
132         6 -> {
133             addrPos = PacketReflector.IPV6_ADDR_OFFSET
134             addrLen = PacketReflector.IPV6_ADDR_LENGTH
135         }
136         else -> throw java.lang.IllegalArgumentException()
137     }
138     swapBytes(buf, addrPos, addrPos + addrLen, addrLen)
139 }
140