1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *
16  */
17 
18 #pragma once
19 
#include <array>
#include <cstdint>
#include <map>
#include <string>
#include <thread>

#include <netinet/in.h>
#include <netinet/ip.h>

#include <android-base/result.h>
#include <android-base/unique_fd.h>
#include <netdutils/Slice.h>
28 
29 namespace android::net {
30 
31 // Given a TUN interface fd, TunForwarder reads packets from the fd, changes their IP header
32 // according to a set of forwarding rules (which can be set by addForwardingRule), and sends
33 // new packets back to the fd. Only IPv4 and IPv6 packets with recognized source and destination
34 // addresses are accepted; other packets are silently ignored.
35 class TunForwarder {
36   public:
37     TunForwarder(base::unique_fd tunFd);
38     ~TunForwarder();
39 
40     bool addForwardingRule(const std::array<std::string, 2>& from,
41                            const std::array<std::string, 2>& to);
42     bool startForwarding();
43     bool stopForwarding();
44 
45     static base::unique_fd createTun(const std::string& ifname);
46 
47   private:
48     // TODO: Considering using IPAddress for v4pair and v6pair. This might requires adding
49     // addr4() and addr6() as IPPrefix does.
50     struct v4pair {
51         static base::Result<v4pair> makePair(const std::array<std::string, 2>& addrs);
52         v4pair() = default;
53         v4pair(int32_t srcAddr, int32_t dstAddr) {
54             src.s_addr = static_cast<in_addr_t>(srcAddr);
55             dst.s_addr = static_cast<in_addr_t>(dstAddr);
56         }
57         in_addr src;
58         in_addr dst;
59         bool operator==(const v4pair& o) const;
60         bool operator<(const v4pair& o) const;
61     };
62 
63     struct v6pair {
64         static base::Result<v6pair> makePair(const std::array<std::string, 2>& addrs);
65         v6pair() = default;
66         v6pair(const in6_addr& srcAddr, const in6_addr& dstAddr) : src(srcAddr), dst(dstAddr) {}
67         in6_addr src;
68         in6_addr dst;
69         bool operator==(const v6pair& o) const;
70         bool operator<(const v6pair& o) const;
71     };
72 
73     void loop();
74     void handlePacket(int fd) const;
75 
76     // Send a signal to terminate the loop thread.
77     bool signalEventFd();
78 
79     // A series of functions to check the packet. Return error if the packet is neither UDP nor TCP.
80     base::Result<void> validatePacket(netdutils::Slice tunPacket) const;
81     base::Result<void> validateIpv4Packet(netdutils::Slice ipv4Packet) const;
82     base::Result<void> validateIpv6Packet(netdutils::Slice ipv6Packet) const;
83     base::Result<void> validateUdpPacket(netdutils::Slice udpPacket) const;
84     base::Result<void> validateTcpPacket(netdutils::Slice tcpPacket) const;
85 
86     // The function assumes |tunPacket| is either UDP or TCP packet, changes the source/destination
87     // addresses, and updates the checksum.
88     base::Result<void> translatePacket(netdutils::Slice tunPacket) const;
89     base::Result<void> translateIpv4Packet(netdutils::Slice ipv4Packet) const;
90     base::Result<void> translateIpv6Packet(netdutils::Slice ipv6Packet) const;
91     void translateUdpPacket(netdutils::Slice udpPacket, uint32_t oldPseudoSum,
92                             uint32_t newPseudoSum) const;
93     void translateTcpPacket(netdutils::Slice tcpPacket, uint32_t oldPseudoSum,
94                             uint32_t newPseudoSum) const;
95 
96     std::thread mForwarder;
97     base::unique_fd mTunFd;
98     base::unique_fd mEventFd;
99     std::map<v4pair, v4pair> mRulesIpv4;
100     std::map<v6pair, v6pair> mRulesIpv6;
101 
102     static constexpr int kPollTimeoutMs = 5000;
103 };
104 
105 }  // namespace android::net
106