1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.testutils
18 
19 import android.net.DnsResolver.CLASS_IN
20 import android.net.DnsResolver.TYPE_AAAA
21 import android.net.Network
22 import androidx.test.ext.junit.runners.AndroidJUnit4
23 import androidx.test.filters.SmallTest
24 import com.android.net.module.util.DnsPacket
25 import com.android.net.module.util.DnsPacket.DnsRecord
26 import libcore.net.InetAddressUtils
27 import org.junit.After
28 import org.junit.Test
29 import org.junit.runner.RunWith
30 import org.mockito.Mockito
31 import java.net.DatagramPacket
32 import java.net.DatagramSocket
33 import java.net.InetAddress
34 import java.net.InetSocketAddress
35 import kotlin.test.assertEquals
36 import kotlin.test.assertFailsWith
37 import kotlin.test.assertFalse
38 import kotlin.test.assertTrue
39 
/** Domain name queried in these tests and registered with the server via `setAnswer`. */
const val TEST_DOMAIN = "hello.example.com"

/** IPv6 answer (from the 2001:db8::/32 documentation prefix) returned for [TEST_DOMAIN]. */
val TEST_V6_ADDR: InetAddress = InetAddressUtils.parseNumericAddress("2001:db8::3")
42 
/**
 * Tests for [TestDnsServer].
 *
 * The server is bound to the local host address (from [InetAddress.getLocalHost]) and is
 * driven over real UDP sockets. The [Network] handed to it is a Mockito mock.
 */
@RunWith(AndroidJUnit4::class)
@SmallTest
class TestDnsServerTest {
    private val network = Mockito.mock(Network::class.java)

    // Port 0: the port actually in use is read back from TestDnsServer.port.
    private val localAddr = InetSocketAddress(InetAddress.getLocalHost(), 0 /* port */)
    private val testServer: TestDnsServer = TestDnsServer(network, localAddr)

    /** Stops [testServer] if the test left it running. */
    @After
    fun tearDown() {
        if (testServer.isAlive) testServer.stop()
    }

    /**
     * Checks the server lifecycle.
     *
     * Fresh servers can be started and stopped repeatedly. The following calls throw
     * [IllegalStateException]:
     * - stopping a server that was never started,
     * - starting a running server,
     * - stopping a stopped server,
     * - restarting a stopped server.
     */
    @Test
    fun testStartStop() {
        repeat(100) {
            val server = TestDnsServer(network, localAddr)
            try {
                server.start()
                assertTrue(server.isAlive)
                server.stop()
                assertFalse(server.isAlive)
            } finally {
                // Don't leave a server running if one of the assertions above fails.
                if (server.isAlive) server.stop()
            }
        }

        // Test illegal start/stop.
        assertFailsWith<IllegalStateException> { testServer.stop() }
        testServer.start()
        assertTrue(testServer.isAlive)
        assertFailsWith<IllegalStateException> { testServer.start() }
        testServer.stop()
        assertFalse(testServer.isAlive)
        assertFailsWith<IllegalStateException> { testServer.stop() }
        // TestDnsServer rejects start after stop.
        assertFailsWith<IllegalStateException> { testServer.start() }
    }

    /**
     * Sends an AAAA query for [TEST_DOMAIN] and checks that the reply matches the expected
     * answer packet, which echoes the question and carries one AAAA record for [TEST_V6_ADDR].
     */
    @Test
    fun testHandleDnsQuery() {
        testServer.setAnswer(TEST_DOMAIN, listOf(TEST_V6_ADDR))
        testServer.start()

        // Mock query and send it to the test server.
        val queryHeader = DnsPacket.DnsHeader(0xbeef /* id */,
                0x0 /* flag */, 1 /* qcount */, 0 /* ancount */)
        val qlist = listOf(DnsRecord.makeQuestion(TEST_DOMAIN, TYPE_AAAA, CLASS_IN))
        val queryPacket = TestDnsServer.DnsQueryPacket(queryHeader, qlist, emptyList())
        val response = resolve(queryPacket, testServer.port)

        // Verify expected answer packet. Set QR bit of flag to 1 for response packet
        // according to RFC 1035 section 4.1.1.
        val answerHeader = DnsPacket.DnsHeader(0xbeef,
                1 shl 15 /* flag */, 1 /* qcount */, 1 /* ancount */)
        val alist = listOf(DnsRecord.makeAOrAAAARecord(DnsPacket.ANSECTION, TEST_DOMAIN,
                CLASS_IN, DEFAULT_TTL_S, TEST_V6_ADDR))
        val expectedAnswerPacket = TestDnsServer.DnsAnswerPacket(answerHeader, qlist, alist)
        assertEquals(expectedAnswerPacket, response)

        // Clean up the server in tearDown.
    }

    /**
     * Sends [queryDnsPacket] over UDP to [serverPort] at [localAddr]'s address and parses the
     * single datagram received in reply.
     *
     * The client socket is always closed before returning, including when send/receive throws.
     *
     * @throws java.net.SocketTimeoutException if no reply arrives within [SOCKET_TIMEOUT_MS].
     */
    private fun resolve(queryDnsPacket: DnsPacket, serverPort: Int): TestDnsServer.DnsAnswerPacket {
        val bytes = queryDnsPacket.bytes
        // Create a new client socket. It binds to an ephemeral port that differs
        // from the server port.
        return DatagramSocket(localAddr).use { socket ->
            socket.soTimeout = SOCKET_TIMEOUT_MS
            socket.send(DatagramPacket(bytes, bytes.size, localAddr.address, serverPort))

            // Send query and wait for the reply.
            val buffer = ByteArray(MAX_BUF_SIZE)
            val reply = DatagramPacket(buffer, buffer.size)
            socket.receive(reply)
            // Only reply.length bytes belong to the datagram. Don't hand the unused tail of
            // the receive buffer to the parser.
            val payload = reply.data.copyOfRange(reply.offset, reply.offset + reply.length)
            TestDnsServer.DnsAnswerPacket(payload)
        }
    }

    // TODO: Add more tests, which includes:
    //  * Empty question RR packet (or more unexpected states)
    //  * No answer found (e.g. call setAnswer with an empty list, as testHandleDnsQuery does
    //    with a non-empty one)
    //  * Test one or multi A record(s)
    //  * Test multi AAAA records
    //  * Test CNAME records

    companion object {
        /**
         * Receive timeout, in milliseconds, for the client socket used by [resolve].
         *
         * NOTE(review): 100 ms may be tight on slow or heavily loaded devices; consider
         * raising it if this test turns out to be flaky.
         */
        private const val SOCKET_TIMEOUT_MS = 100
    }
}
123