1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 * 16 */ 17 18 #define LOG_TAG "TunForwarder" 19 20 #include "tun_forwarder.h" 21 22 #include <arpa/inet.h> 23 #include <linux/if.h> 24 #include <linux/if_tun.h> 25 #include <linux/ioctl.h> 26 #include <netinet/ip6.h> 27 #include <netinet/tcp.h> 28 #include <netinet/udp.h> 29 #include <sys/eventfd.h> 30 #include <sys/poll.h> 31 32 #include <android-base/logging.h> 33 34 extern "C" { 35 #include <netutils/checksum.h> 36 } 37 38 using android::base::Error; 39 using android::base::Result; 40 using android::base::unique_fd; 41 using android::netdutils::Slice; 42 43 namespace android::net { 44 45 static constexpr int MAXMTU = 1500; 46 static constexpr ssize_t TUN_HDRLEN = sizeof(struct tun_pi); 47 static constexpr ssize_t IP4_HDRLEN = sizeof(struct iphdr); 48 static constexpr ssize_t IP6_HDRLEN = sizeof(struct ip6_hdr); 49 static constexpr ssize_t TCP_HDRLEN = sizeof(struct tcphdr); 50 static constexpr ssize_t UDP_HDRLEN = sizeof(struct udphdr); 51 52 namespace { 53 54 bool operator==(const in6_addr& x, const in6_addr& y) { 55 return std::memcmp(x.s6_addr, y.s6_addr, 16) == 0; 56 } 57 58 bool operator!=(const in6_addr& x, const in6_addr& y) { 59 return !(x == y); 60 } 61 62 bool operator<(const in6_addr& x, const in6_addr& y) { 63 return std::memcmp(x.s6_addr, y.s6_addr, 16) < 0; 64 } 65 66 } // namespace 67 68 Result<TunForwarder::v4pair> TunForwarder::v4pair::makePair( 69 const std::array<std::string, 2>& addrs) { 70 v4pair pair; 71 if (inet_pton(AF_INET, addrs[0].c_str(), &pair.src) != 1 || 72 inet_pton(AF_INET, addrs[1].c_str(), &pair.dst) != 1) { 73 return Error() << "Failed to make v4pair"; 74 } 75 return pair; 76 } 77 78 bool TunForwarder::v4pair::operator==(const v4pair& o) const { 79 return std::tie(src.s_addr, dst.s_addr) == std::tie(o.src.s_addr, o.dst.s_addr); 80 } 81 82 bool TunForwarder::v4pair::operator<(const v4pair& o) const { 83 return std::tie(src.s_addr, dst.s_addr) < std::tie(o.src.s_addr, o.dst.s_addr); 84 } 85 86 Result<TunForwarder::v6pair> TunForwarder::v6pair::makePair( 87 const std::array<std::string, 2>& addrs) { 88 v6pair pair; 89 if (inet_pton(AF_INET6, addrs[0].c_str(), &pair.src) != 1 || 90 inet_pton(AF_INET6, addrs[1].c_str(), &pair.dst) != 1) { 91 return Error() << "Failed to make v6pair"; 92 } 93 return pair; 94 } 95 96 bool TunForwarder::v6pair::operator==(const v6pair& o) const { 97 return src == o.src && dst == o.dst; 98 } 99 100 bool TunForwarder::v6pair::operator<(const v6pair& o) const { 101 if (src != o.src) return src < o.src; 102 return dst < o.dst; 103 } 104 105 TunForwarder::TunForwarder(unique_fd tunFd) : mTunFd(std::move(tunFd)) { 106 mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC)); 107 } 108 109 TunForwarder::~TunForwarder() { 110 stopForwarding(); 111 if (mForwarder.joinable()) { 112 mForwarder.join(); 113 } 114 } 115 116 bool TunForwarder::startForwarding() { 117 if (mForwarder.joinable()) return false; 118 mForwarder = std::thread(&TunForwarder::loop, this); 119 return true; 120 } 121 122 bool TunForwarder::stopForwarding() { 123 return signalEventFd(); 124 } 125 126 // Assume all of the strings in |from| and |to| are the IP addresses of the same IP version. 127 bool TunForwarder::addForwardingRule(const std::array<std::string, 2>& from, 128 const std::array<std::string, 2>& to) { 129 const bool isV4 = (from[0].find(':') == from[0].npos); 130 if (isV4) { 131 auto k = v4pair::makePair(from); 132 auto v = v4pair::makePair(to); 133 if (!k.ok() || !v.ok()) return false; 134 mRulesIpv4[k.value()] = v.value(); 135 } else { 136 auto k = v6pair::makePair(from); 137 auto v = v6pair::makePair(to); 138 if (!k.ok() || !v.ok()) return false; 139 mRulesIpv6[k.value()] = v.value(); 140 } 141 return true; 142 } 143 144 unique_fd TunForwarder::createTun(const std::string& ifname) { 145 unique_fd fd(open("/dev/tun", O_RDWR | O_NONBLOCK | O_CLOEXEC)); 146 if (!fd.ok() == -1) { 147 return {}; 148 } 149 150 ifreq ifr = { 151 .ifr_ifru = {.ifru_flags = IFF_TUN}, 152 }; 153 strlcpy(ifr.ifr_name, ifname.data(), sizeof(ifr.ifr_name)); 154 155 if (ioctl(fd.get(), TUNSETIFF, &ifr) == -1) { 156 PLOG(WARNING) << "failed to bring up tun " << ifr.ifr_name; 157 return {}; 158 } 159 160 unique_fd inet6CtrlSock(socket(AF_INET6, SOCK_DGRAM | SOCK_CLOEXEC, 0)); 161 ifr.ifr_flags = IFF_UP; 162 if (ioctl(inet6CtrlSock.get(), SIOCSIFFLAGS, &ifr) == -1) { 163 PLOG(WARNING) << "failed on SIOCSIFFLAGS " << ifr.ifr_name; 164 return {}; 165 } 166 167 return fd; 168 } 169 170 void TunForwarder::loop() { 171 while (true) { 172 struct pollfd wait_fd[] = { 173 {mEventFd, POLLIN, 0}, 174 {mTunFd.get(), POLLIN, 0}, 175 }; 176 177 if (int ret = poll(wait_fd, std::size(wait_fd), kPollTimeoutMs); ret <= 0) { 178 break; 179 } 180 181 if (wait_fd[0].revents & (POLLIN | POLLERR)) { 182 uint64_t value = 0; 183 eventfd_read(mEventFd, &value); 184 break; 185 } 186 if (wait_fd[1].revents & (POLLIN | POLLERR)) { 187 handlePacket(wait_fd[1].fd); 188 } 189 } 190 } 191 192 void TunForwarder::handlePacket(int fd) const { 193 uint8_t buf[MAXMTU + TUN_HDRLEN]; 194 195 ssize_t readlen = read(fd, buf, std::size(buf)); 196 if (readlen < 0) { 197 PLOG(ERROR) << "failed to read packets from tun"; 198 return; 199 } else if (readlen == 0) { 200 PLOG(ERROR) << "tun interface removed"; 201 return; 202 } 203 204 // Filter the packet. Only TCP and UDP packets are allowed. 205 const Slice tunPacket(buf, readlen); 206 if (auto result = validatePacket(tunPacket); !result.ok()) { 207 LOG(DEBUG) << "validatePacket failed: " << result.error(); 208 return; 209 } 210 211 // Change the packet's source/destination address and checksum. 212 if (auto result = translatePacket(tunPacket); !result.ok()) { 213 LOG(ERROR) << "translatePacket failed: " << result.error(); 214 } 215 216 // Write the new packet to the fd, causing the kernel to receive it on the tun interface. 217 write(fd, buf, readlen); 218 } 219 220 Result<void> TunForwarder::validatePacket(Slice tunPacket) const { 221 if (tunPacket.size() < TUN_HDRLEN) { 222 return Error() << "Too short for a tun header"; 223 } 224 225 const tun_pi* const tunHeader = reinterpret_cast<tun_pi*>(tunPacket.base()); 226 if (tunHeader->flags != 0) { 227 return Error() << "Unexpected tun flags " << static_cast<int>(tunHeader->flags); 228 } 229 230 switch (uint16_t proto = ntohs(tunHeader->proto); proto) { 231 case ETH_P_IP: 232 return validateIpv4Packet(drop(tunPacket, TUN_HDRLEN)); 233 case ETH_P_IPV6: 234 return validateIpv6Packet(drop(tunPacket, TUN_HDRLEN)); 235 default: 236 return Error() << "Unsupported packet type 0x" << std::hex << static_cast<int>(proto); 237 } 238 } 239 240 Result<void> TunForwarder::validateIpv4Packet(Slice ipv4Packet) const { 241 if (ipv4Packet.size() < IP4_HDRLEN) { 242 return Error() << "Too short for an ip header"; 243 } 244 245 const iphdr* const ipHeader = reinterpret_cast<iphdr*>(ipv4Packet.base()); 246 if (ipHeader->ihl < 5) { 247 return Error() << "IP header length set to less than 5"; 248 } 249 if (ipHeader->ihl * 4 > ipv4Packet.size()) { 250 return Error() << "IP header length set too large: " << ipHeader->ihl; 251 } 252 if (ipHeader->version != 4) { 253 return Error() << "IP header version not 4: " << ipHeader->version; 254 } 255 if (mRulesIpv4.find({ipHeader->saddr, ipHeader->daddr}) == mRulesIpv4.end()) { 256 return Error() << "Can't find any v4 rule. Packet hex dump: " << toHex(ipv4Packet, 32); 257 } 258 259 switch (ipHeader->protocol) { 260 case IPPROTO_UDP: 261 return validateUdpPacket(drop(ipv4Packet, ipHeader->ihl * 4)); 262 case IPPROTO_TCP: 263 return validateTcpPacket(drop(ipv4Packet, ipHeader->ihl * 4)); 264 default: 265 return Error() << "Unsupported transport protocol " 266 << static_cast<int>(ipHeader->protocol); 267 } 268 } 269 270 Result<void> TunForwarder::validateIpv6Packet(Slice ipv6Packet) const { 271 if (ipv6Packet.size() < IP6_HDRLEN) { 272 return Error() << "Too short for an ipv6 header"; 273 } 274 275 const ip6_hdr* const ipv6Header = reinterpret_cast<ip6_hdr*>(ipv6Packet.base()); 276 if (mRulesIpv6.find({ipv6Header->ip6_src, ipv6Header->ip6_dst}) == mRulesIpv6.end()) { 277 return Error() << "Can't find any v6 rule. Packet hex dump: " << toHex(ipv6Packet, 32); 278 } 279 280 switch (ipv6Header->ip6_nxt) { 281 case IPPROTO_UDP: 282 return validateUdpPacket(drop(ipv6Packet, IP6_HDRLEN)); 283 case IPPROTO_TCP: 284 return validateTcpPacket(drop(ipv6Packet, IP6_HDRLEN)); 285 default: 286 return Error() << "Expect TCP/UDP in ipv6 next header: " 287 << static_cast<int>(ipv6Header->ip6_nxt); 288 } 289 } 290 291 Result<void> TunForwarder::validateUdpPacket(Slice udpPacket) const { 292 if (udpPacket.size() < UDP_HDRLEN) { 293 return Error() << "Too short for a udp header"; 294 } 295 return {}; 296 } 297 298 Result<void> TunForwarder::validateTcpPacket(Slice tcpPacket) const { 299 if (tcpPacket.size() < TCP_HDRLEN) { 300 return Error() << "Too short for a tcp header"; 301 } 302 303 const tcphdr* const tcpHeader = reinterpret_cast<tcphdr*>(tcpPacket.base()); 304 if (tcpHeader->doff < 5) { 305 return Error() << "TCP header length set to less than 5"; 306 } 307 if (tcpHeader->doff * 4 > tcpPacket.size()) { 308 return Error() << "TCP header length set too large: " << tcpHeader->doff; 309 } 310 return {}; 311 } 312 313 Result<void> TunForwarder::translatePacket(Slice tunPacket) const { 314 const tun_pi* const tunHeader = reinterpret_cast<tun_pi*>(tunPacket.base()); 315 switch (uint16_t proto = ntohs(tunHeader->proto); proto) { 316 case ETH_P_IP: 317 return translateIpv4Packet(drop(tunPacket, TUN_HDRLEN)); 318 case ETH_P_IPV6: 319 return translateIpv6Packet(drop(tunPacket, TUN_HDRLEN)); 320 default: 321 return Error() << "translate: Unsupported packet type 0x" << std::hex 322 << static_cast<int>(proto); 323 } 324 } 325 326 Result<void> TunForwarder::translateIpv4Packet(Slice ipv4Packet) const { 327 iphdr* ipHeader = reinterpret_cast<iphdr*>(ipv4Packet.base()); 328 const size_t ipHeaderLen = ipHeader->ihl * 4; 329 const size_t transport_len = ipv4Packet.size() - ipHeaderLen; 330 331 uint32_t oldPseudoSum = ipv4_pseudo_header_checksum(ipHeader, transport_len); 332 for (const auto& [from, to] : mRulesIpv4) { 333 if (ipHeader->saddr == static_cast<int>(from.src.s_addr) && 334 ipHeader->daddr == static_cast<int>(from.dst.s_addr)) { 335 ipHeader->saddr = to.src.s_addr; 336 ipHeader->daddr = to.dst.s_addr; 337 break; 338 } 339 } 340 uint32_t newPseudoSum = ipv4_pseudo_header_checksum(ipHeader, transport_len); 341 342 ipHeader->check = 0; 343 ipHeader->check = ip_checksum(ipHeader, sizeof(struct iphdr)); 344 345 switch (ipHeader->protocol) { 346 case IPPROTO_UDP: 347 translateUdpPacket(drop(ipv4Packet, ipHeaderLen), oldPseudoSum, newPseudoSum); 348 break; 349 case IPPROTO_TCP: 350 translateTcpPacket(drop(ipv4Packet, ipHeaderLen), oldPseudoSum, newPseudoSum); 351 break; 352 default: 353 return Error() << "translate: Unsupported transport protocol " 354 << static_cast<int>(ipHeader->protocol); 355 } 356 357 return {}; 358 } 359 360 Result<void> TunForwarder::translateIpv6Packet(Slice ipv6Packet) const { 361 ip6_hdr* ipv6Header = reinterpret_cast<ip6_hdr*>(ipv6Packet.base()); 362 const size_t ipHeaderLen = IP6_HDRLEN; 363 const size_t transport_len = ipv6Packet.size() - ipHeaderLen; 364 365 uint32_t oldPseudoSum = 366 ipv6_pseudo_header_checksum(ipv6Header, transport_len, ipv6Header->ip6_nxt); 367 for (const auto& [from, to] : mRulesIpv6) { 368 if (ipv6Header->ip6_src == from.src && ipv6Header->ip6_dst == from.dst) { 369 ipv6Header->ip6_src = to.src; 370 ipv6Header->ip6_dst = to.dst; 371 break; 372 } 373 } 374 uint32_t newPseudoSum = 375 ipv6_pseudo_header_checksum(ipv6Header, transport_len, ipv6Header->ip6_nxt); 376 377 switch (ipv6Header->ip6_nxt) { 378 case IPPROTO_UDP: 379 translateUdpPacket(drop(ipv6Packet, ipHeaderLen), oldPseudoSum, newPseudoSum); 380 break; 381 case IPPROTO_TCP: 382 translateTcpPacket(drop(ipv6Packet, ipHeaderLen), oldPseudoSum, newPseudoSum); 383 break; 384 default: 385 return Error() << "transliate: Expect TCP/UDP in ipv6 next header: " 386 << static_cast<int>(ipv6Header->ip6_nxt); 387 } 388 389 return {}; 390 } 391 392 void TunForwarder::translateUdpPacket(Slice udpPacket, uint32_t oldPseudoSum, 393 uint32_t newPseudoSum) const { 394 udphdr* udpHeader = reinterpret_cast<udphdr*>(udpPacket.base()); 395 if (udpHeader->check) { 396 udpHeader->check = ip_checksum_adjust(udpHeader->check, oldPseudoSum, newPseudoSum); 397 } else { 398 uint32_t tmp = ip_checksum_add(newPseudoSum, udpPacket.base(), udpPacket.size()); 399 udpHeader->check = ip_checksum_finish(tmp); 400 } 401 402 // RFC 768: "If the computed checksum is zero, it is transmitted as all ones (the equivalent 403 // in one's complement arithmetic)." 404 if (!udpHeader->check) { 405 udpHeader->check = 0xffff; 406 } 407 } 408 409 void TunForwarder::translateTcpPacket(Slice tcpPacket, uint32_t oldPseudoSum, 410 uint32_t newPseudoSum) const { 411 tcphdr* tcpHeader = reinterpret_cast<tcphdr*>(tcpPacket.base()); 412 tcpHeader->check = ip_checksum_adjust(tcpHeader->check, oldPseudoSum, newPseudoSum); 413 } 414 415 bool TunForwarder::signalEventFd() { 416 return eventfd_write(mEventFd.get(), 1) == 0; 417 } 418 419 } // namespace android::net 420