1 /*
2  *
3  * Copyright (C) 2017 The Android Open Source Project
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 #include <random>
19 #include <string>
20 #include <vector>
21 
22 #include <ctype.h>
23 #include <errno.h>
24 #include <fcntl.h>
25 #include <getopt.h>
26 #include <stdio.h>
27 #include <stdlib.h>
28 #include <string.h>
29 
30 #define __STDC_FORMAT_MACROS
31 #include <inttypes.h>
32 
33 #include <arpa/inet.h>
34 #include <netinet/in.h>
35 
36 #include <sys/socket.h>
37 #include <sys/stat.h>
38 #include <sys/types.h>
39 #include <sys/wait.h>
40 
41 #include <linux/in.h>
42 #include <linux/ipsec.h>
43 #include <linux/netlink.h>
44 #include <linux/xfrm.h>
45 
46 #define LOG_TAG "XfrmController"
47 #include "InterfaceController.h"
48 #include "NetdConstants.h"
49 #include "NetlinkCommands.h"
50 #include "ResponseCode.h"
51 #include "XfrmController.h"
52 #include "netdutils/Fd.h"
53 #include "netdutils/Slice.h"
54 #include "netdutils/Syscalls.h"
55 #include <android-base/properties.h>
56 #include <android-base/stringprintf.h>
57 #include <android-base/strings.h>
58 #include <android-base/unique_fd.h>
59 #include <android/net/INetd.h>
60 #include <cutils/log.h>
61 #include <cutils/properties.h>
62 #include <logwrap/logwrap.h>
63 
64 using android::net::INetd;
65 using android::netdutils::Fd;
66 using android::netdutils::Slice;
67 using android::netdutils::Status;
68 using android::netdutils::StatusOr;
69 using android::netdutils::Syscalls;
70 
71 namespace android {
72 namespace net {
73 
// All-ones 32-bit mask (~0 sets every bit); presumably selects every authentication
// algorithm -- confirm against the code that consumes it.
// Exposed for testing
constexpr uint32_t ALGO_MASK_AUTH_ALL = ~0;
// All-ones 32-bit mask; presumably the encryption-algorithm counterpart of the mask above.
// Exposed for testing
constexpr uint32_t ALGO_MASK_CRYPT_ALL = ~0;
// All-ones 32-bit mask; presumably the AEAD-algorithm counterpart of the masks above.
// Exposed for testing
constexpr uint32_t ALGO_MASK_AEAD_ALL = ~0;
// NOTE(review): looks like the anti-replay window size used for new SAs -- the consumer is
// not visible in this part of the file, so confirm before relying on that.
// Exposed for testing
constexpr uint8_t REPLAY_WINDOW_SIZE = 4;
82 
83 namespace {
84 
// Default inclusive SPI range handed to allocateSpi() by ipSecAllocateSpi() when the caller
// does not request a specific SPI.
constexpr uint32_t RAND_SPI_MIN = 256;
constexpr uint32_t RAND_SPI_MAX = 0xFFFFFFFE;

// Sentinel SPI value: passed to fillXfrmId() before an SPI has been allocated, and written to
// the output parameter of ipSecAllocateSpi() when allocation fails.
constexpr uint32_t INVALID_SPI = 0;
89 
isEngBuild()90 static inline bool isEngBuild() {
91     static const std::string sBuildType = android::base::GetProperty("ro.build.type", "user");
92     return sBuildType == "eng";
93 }
94 
// Expands to a switch case that returns the spelling of the constant `x` as a string literal.
#define XFRM_MSG_TRANS(x)                                                                          \
    case x:                                                                                        \
        return #x;

// Maps an XFRM netlink message type to the name of its XFRM_MSG_* constant, for logging.
// Any value without a case below yields "XFRM_MSG UNKNOWN".
const char* xfrmMsgTypeToString(uint16_t msg) {
    switch (msg) {
        // *SA and SPI messages
        XFRM_MSG_TRANS(XFRM_MSG_NEWSA)
        XFRM_MSG_TRANS(XFRM_MSG_UPDSA)
        XFRM_MSG_TRANS(XFRM_MSG_DELSA)
        XFRM_MSG_TRANS(XFRM_MSG_GETSA)
        XFRM_MSG_TRANS(XFRM_MSG_FLUSHSA)
        XFRM_MSG_TRANS(XFRM_MSG_ALLOCSPI)

        // *POLICY messages
        XFRM_MSG_TRANS(XFRM_MSG_NEWPOLICY)
        XFRM_MSG_TRANS(XFRM_MSG_UPDPOLICY)
        XFRM_MSG_TRANS(XFRM_MSG_DELPOLICY)
        XFRM_MSG_TRANS(XFRM_MSG_GETPOLICY)
        XFRM_MSG_TRANS(XFRM_MSG_FLUSHPOLICY)

        // ACQUIRE / EXPIRE / REPORT / MIGRATE / MAPPING messages
        XFRM_MSG_TRANS(XFRM_MSG_ACQUIRE)
        XFRM_MSG_TRANS(XFRM_MSG_EXPIRE)
        XFRM_MSG_TRANS(XFRM_MSG_POLEXPIRE)
        XFRM_MSG_TRANS(XFRM_MSG_REPORT)
        XFRM_MSG_TRANS(XFRM_MSG_MIGRATE)
        XFRM_MSG_TRANS(XFRM_MSG_MAPPING)

        // *AE messages
        XFRM_MSG_TRANS(XFRM_MSG_NEWAE)
        XFRM_MSG_TRANS(XFRM_MSG_GETAE)

        // *SADINFO / *SPDINFO messages
        XFRM_MSG_TRANS(XFRM_MSG_NEWSADINFO)
        XFRM_MSG_TRANS(XFRM_MSG_GETSADINFO)
        XFRM_MSG_TRANS(XFRM_MSG_NEWSPDINFO)
        XFRM_MSG_TRANS(XFRM_MSG_GETSPDINFO)
    }

    // No case matched.
    return "XFRM_MSG UNKNOWN";
}
128 
// Three zero bytes used as netlink alignment padding. The contents are never meant to be
// modified, but the buffer is exposed as a non-const void* so that it can be placed directly
// into an iovec's iov_base (see flushSaDb), which is why it is not declared const.
uint8_t kPadBytesArray[] = {0, 0, 0};
void* kPadBytes = static_cast<void*>(kPadBytesArray);
132 
// Hex-dumps __len__ bytes of __buf__ through logHex(), labelled with __desc16__, but only
// when isEngBuild() is true; on other builds nothing is logged.
#define LOG_HEX(__desc16__, __buf__, __len__)                                                      \
    do {                                                                                           \
        if (isEngBuild()) {                                                                        \
            logHex(__desc16__, __buf__, __len__);                                                  \
        }                                                                                          \
    } while (0)

// Hex-dumps every entry of the iovec vector __iov__ through logIov(), again only when
// isEngBuild() is true.
#define LOG_IOV(__iov__)                                                                           \
    do {                                                                                           \
        if (isEngBuild()) {                                                                        \
            logIov(__iov__);                                                                       \
        }                                                                                          \
    } while (0)
146 
logHex(const char * desc16,const char * buf,size_t len)147 void logHex(const char* desc16, const char* buf, size_t len) {
148     char* printBuf = new char[len * 2 + 1 + 26]; // len->ascii, +newline, +prefix strlen
149     int offset = 0;
150     if (desc16) {
151         sprintf(printBuf, "{%-16s}", desc16);
152         offset += 18; // prefix string length
153     }
154     sprintf(printBuf + offset, "[%4.4u]: ", (len > 9999) ? 9999 : (unsigned)len);
155     offset += 8;
156 
157     for (uint32_t j = 0; j < (uint32_t)len; j++) {
158         sprintf(&printBuf[j * 2 + offset], "%0.2x", (unsigned char)buf[j]);
159     }
160     ALOGD("%s", printBuf);
161     delete[] printBuf;
162 }
163 
logIov(const std::vector<iovec> & iov)164 void logIov(const std::vector<iovec>& iov) {
165     for (const iovec& row : iov) {
166         logHex(0, reinterpret_cast<char*>(row.iov_base), row.iov_len);
167     }
168 }
169 
// Fills in a netlink attribute header for a payload of `valueSize` bytes.
//
// nla_len is set to the header size plus the payload size (padding excluded) and nla_type to
// `nlaType`. Returns the number of padding bytes needed after the payload to reach the next
// NLMSG_ALIGN boundary.
size_t fillNlAttr(__u16 nlaType, size_t valueSize, nlattr* nlAttr) {
    nlAttr->nla_type = nlaType;
    nlAttr->nla_len = (__u16)(sizeof(nlattr) + valueSize);

    const int padding = NLMSG_ALIGN(valueSize) - valueSize;
    return padding;
}
177 
fillNlAttrIpAddress(__u16 nlaType,int family,const std::string & value,nlattr * nlAttr,Slice ipAddress)178 size_t fillNlAttrIpAddress(__u16 nlaType, int family, const std::string& value, nlattr* nlAttr,
179                            Slice ipAddress) {
180     inet_pton(family, value.c_str(), ipAddress.base());
181     return fillNlAttr(nlaType, (family == AF_INET) ? sizeof(in_addr) : sizeof(in6_addr), nlAttr);
182 }
183 
fillNlAttrU32(__u16 nlaType,int32_t value,nlattr * nlAttr,uint32_t * u32Value)184 size_t fillNlAttrU32(__u16 nlaType, int32_t value, nlattr* nlAttr, uint32_t* u32Value) {
185     *u32Value = htonl(value);
186     return fillNlAttr(nlaType, sizeof((*u32Value)), nlAttr);
187 }
188 
189 // returns the address family, placing the string in the provided buffer
convertStringAddress(std::string addr,uint8_t * buffer)190 StatusOr<uint16_t> convertStringAddress(std::string addr, uint8_t* buffer) {
191     if (inet_pton(AF_INET, addr.c_str(), buffer) == 1) {
192         return AF_INET;
193     } else if (inet_pton(AF_INET6, addr.c_str(), buffer) == 1) {
194         return AF_INET6;
195     } else {
196         return Status(EAFNOSUPPORT);
197     }
198 }
199 
200 // TODO: Need to consider a way to refer to the sSycalls instance
getSyscallInstance()201 inline Syscalls& getSyscallInstance() { return netdutils::sSyscalls.get(); }
202 
203 class XfrmSocketImpl : public XfrmSocket {
204 private:
205     static constexpr int NLMSG_DEFAULTSIZE = 8192;
206 
207     union NetlinkResponse {
208         nlmsghdr hdr;
209         struct _err_ {
210             nlmsghdr hdr;
211             nlmsgerr err;
212         } err;
213 
214         struct _buf_ {
215             nlmsghdr hdr;
216             char buf[NLMSG_DEFAULTSIZE];
217         } buf;
218     };
219 
220 public:
open()221     netdutils::Status open() override {
222         mSock = openNetlinkSocket(NETLINK_XFRM);
223         if (mSock < 0) {
224             ALOGW("Could not get a new socket, line=%d", __LINE__);
225             return netdutils::statusFromErrno(-mSock, "Could not open netlink socket");
226         }
227 
228         return netdutils::status::ok;
229     }
230 
validateResponse(NetlinkResponse response,size_t len)231     static netdutils::Status validateResponse(NetlinkResponse response, size_t len) {
232         if (len < sizeof(nlmsghdr)) {
233             ALOGW("Invalid response message received over netlink");
234             return netdutils::statusFromErrno(EBADMSG, "Invalid message");
235         }
236 
237         switch (response.hdr.nlmsg_type) {
238             case NLMSG_NOOP:
239             case NLMSG_DONE:
240                 return netdutils::status::ok;
241             case NLMSG_OVERRUN:
242                 ALOGD("Netlink request overran kernel buffer");
243                 return netdutils::statusFromErrno(EBADMSG, "Kernel buffer overrun");
244             case NLMSG_ERROR:
245                 if (len < sizeof(NetlinkResponse::_err_)) {
246                     ALOGD("Netlink message received malformed error response");
247                     return netdutils::statusFromErrno(EBADMSG, "Malformed error response");
248                 }
249                 return netdutils::statusFromErrno(
250                     -response.err.err.error,
251                     "Error netlink message"); // Netlink errors are negative errno.
252             case XFRM_MSG_NEWSA:
253                 break;
254         }
255 
256         if (response.hdr.nlmsg_type < XFRM_MSG_BASE /*== NLMSG_MIN_TYPE*/ ||
257             response.hdr.nlmsg_type > XFRM_MSG_MAX) {
258             ALOGD("Netlink message responded with an out-of-range message ID");
259             return netdutils::statusFromErrno(EBADMSG, "Invalid message ID");
260         }
261 
262         // TODO Add more message validation here
263         return netdutils::status::ok;
264     }
265 
sendMessage(uint16_t nlMsgType,uint16_t nlMsgFlags,uint16_t nlMsgSeqNum,std::vector<iovec> * iovecs) const266     netdutils::Status sendMessage(uint16_t nlMsgType, uint16_t nlMsgFlags, uint16_t nlMsgSeqNum,
267                                   std::vector<iovec>* iovecs) const override {
268         nlmsghdr nlMsg = {
269             .nlmsg_type = nlMsgType,
270             .nlmsg_flags = nlMsgFlags,
271             .nlmsg_seq = nlMsgSeqNum,
272         };
273 
274         (*iovecs)[0].iov_base = &nlMsg;
275         (*iovecs)[0].iov_len = NLMSG_HDRLEN;
276         for (const iovec& iov : *iovecs) {
277             nlMsg.nlmsg_len += iov.iov_len;
278         }
279 
280         ALOGD("Sending Netlink XFRM Message: %s", xfrmMsgTypeToString(nlMsgType));
281         LOG_IOV(*iovecs);
282 
283         StatusOr<size_t> writeResult = getSyscallInstance().writev(mSock, *iovecs);
284         if (!isOk(writeResult)) {
285             ALOGE("netlink socket writev failed (%s)", toString(writeResult).c_str());
286             return writeResult;
287         }
288 
289         if (nlMsg.nlmsg_len != writeResult.value()) {
290             ALOGE("Invalid netlink message length sent %d", static_cast<int>(writeResult.value()));
291             return netdutils::statusFromErrno(EBADMSG, "Invalid message length");
292         }
293 
294         NetlinkResponse response = {};
295 
296         StatusOr<Slice> readResult =
297             getSyscallInstance().read(Fd(mSock), netdutils::makeSlice(response));
298         if (!isOk(readResult)) {
299             ALOGE("netlink response error (%s)", toString(readResult).c_str());
300             return readResult;
301         }
302 
303         LOG_HEX("netlink msg resp", reinterpret_cast<char*>(readResult.value().base()),
304                 readResult.value().size());
305 
306         Status validateStatus = validateResponse(response, readResult.value().size());
307         if (!isOk(validateStatus)) {
308             ALOGE("netlink response contains error (%s)", toString(validateStatus).c_str());
309         }
310 
311         return validateStatus;
312     }
313 };
314 
convertToXfrmAddr(const std::string & strAddr,xfrm_address_t * xfrmAddr)315 StatusOr<int> convertToXfrmAddr(const std::string& strAddr, xfrm_address_t* xfrmAddr) {
316     if (strAddr.length() == 0) {
317         memset(xfrmAddr, 0, sizeof(*xfrmAddr));
318         return AF_UNSPEC;
319     }
320 
321     if (inet_pton(AF_INET6, strAddr.c_str(), reinterpret_cast<void*>(xfrmAddr))) {
322         return AF_INET6;
323     } else if (inet_pton(AF_INET, strAddr.c_str(), reinterpret_cast<void*>(xfrmAddr))) {
324         return AF_INET;
325     } else {
326         return netdutils::statusFromErrno(EAFNOSUPPORT, "Invalid address family");
327     }
328 }
329 
// Writes the given attribute type and length into the netlink attribute header `hdr`.
void fillXfrmNlaHdr(nlattr* hdr, uint16_t type, uint16_t len) {
    nlattr& attr = *hdr;
    attr.nla_len = len;
    attr.nla_type = type;
}
334 
// Resets every byte of the current-lifetime record `cur` to zero.
void fillXfrmCurLifetimeDefaults(xfrm_lifetime_cur* cur) {
    memset(cur, 0, sizeof(xfrm_lifetime_cur));
}
// Sets the soft and hard byte and packet limits of `cfg` to XFRM_INF.
// No other field of `cfg` is modified.
void fillXfrmLifetimeDefaults(xfrm_lifetime_cfg* cfg) {
    cfg->hard_packet_limit = XFRM_INF;
    cfg->soft_packet_limit = XFRM_INF;
    cfg->hard_byte_limit = XFRM_INF;
    cfg->soft_byte_limit = XFRM_INF;
}
344 
/*
 * Hands out SPIs within an (inclusive) range of min-max.
 *
 * Candidates are produced by stepping sequentially, modulo the range size, from a randomly
 * chosen starting point, so each value in the range is returned once. next() returns
 * 0 (INVALID_SPI) once the entire range has been handed out.
 *
 * min and max are declared as int but are treated as 32-bit unsigned bit patterns, which lets
 * callers pass SPI bounds above INT_MAX (such as 0xFFFFFFFE) through the int parameters.
 */
class RandomSpi {
public:
    RandomSpi(int min, int max)
        : mNext(randomStart()),
          // Computed in unsigned arithmetic: the subtraction is well defined for every pair
          // of bounds, whereas the signed form overflows for wide ranges.
          mSize(static_cast<uint32_t>(max) - static_cast<uint32_t>(min) + 1),
          mMin(static_cast<uint32_t>(min)),
          mCount(mSize) {}

    // Returns the next unused SPI of the range, or 0 when every value has been returned.
    uint32_t next() {
        if (mCount == 0) {
            return 0;
        }
        mCount--;
        return (mNext++ % mSize) + mMin;
    }

private:
    // Picks the starting offset uniformly from [1, 2^31 - 1].
    // Re-seeding for every instance should be safe because the seed itself is sufficiently
    // random and we don't need secure random.
    static uint32_t randomStart() {
        constexpr uint32_t kMaxStart = 0x7FFFFFFF;
        std::mt19937 rnd(std::random_device{}());
        return std::uniform_int_distribution<uint32_t>(1, kMaxStart)(rnd);
    }

    uint32_t mNext;  // running counter; reduced modulo mSize to select the next SPI
    uint32_t mSize;  // number of values in the range
    uint32_t mMin;   // lowest SPI of the range
    uint32_t mCount; // number of values not yet handed out
};
373 
374 } // namespace
375 
376 //
377 // Begin XfrmController Impl
378 //
379 //
XfrmController(void)380 XfrmController::XfrmController(void) {}
381 
Init()382 netdutils::Status XfrmController::Init() {
383     RETURN_IF_NOT_OK(flushInterfaces());
384     XfrmSocketImpl sock;
385     RETURN_IF_NOT_OK(sock.open());
386     RETURN_IF_NOT_OK(flushSaDb(sock));
387     return flushPolicyDb(sock);
388 }
389 
flushInterfaces()390 netdutils::Status XfrmController::flushInterfaces() {
391     const auto& ifaces = InterfaceController::getIfaceNames();
392     RETURN_IF_NOT_OK(ifaces);
393     const String8 ifPrefix8 = String8(INetd::IPSEC_INTERFACE_PREFIX().string());
394 
395     for (const std::string& iface : ifaces.value()) {
396         int status = 0;
397         // Look for the reserved interface prefix, which must be in the name at position 0
398         if (!iface.compare(0, ifPrefix8.length(), ifPrefix8.c_str()) &&
399             (status = removeVirtualTunnelInterface(iface)) < 0) {
400             ALOGE("Failed to delete ipsec tunnel %s.", iface.c_str());
401             return netdutils::statusFromErrno(status, "Failed to remove ipsec tunnel.");
402         }
403     }
404     return netdutils::status::ok;
405 }
406 
flushSaDb(const XfrmSocket & s)407 netdutils::Status XfrmController::flushSaDb(const XfrmSocket& s) {
408     struct xfrm_usersa_flush flushUserSa = {.proto = IPSEC_PROTO_ANY};
409 
410     std::vector<iovec> iov = {{NULL, 0}, // reserved for the eventual addition of a NLMSG_HDR
411                               {&flushUserSa, sizeof(flushUserSa)}, // xfrm_usersa_flush structure
412                               {kPadBytes, NLMSG_ALIGN(sizeof(flushUserSa)) - sizeof(flushUserSa)}};
413 
414     return s.sendMessage(XFRM_MSG_FLUSHSA, NETLINK_REQUEST_FLAGS, 0, &iov);
415 }
416 
flushPolicyDb(const XfrmSocket & s)417 netdutils::Status XfrmController::flushPolicyDb(const XfrmSocket& s) {
418     std::vector<iovec> iov = {{NULL, 0}}; // reserved for the eventual addition of a NLMSG_HDR
419     return s.sendMessage(XFRM_MSG_FLUSHPOLICY, NETLINK_REQUEST_FLAGS, 0, &iov);
420 }
421 
ipSecSetEncapSocketOwner(const android::base::unique_fd & socket,int newUid,uid_t callerUid)422 netdutils::Status XfrmController::ipSecSetEncapSocketOwner(const android::base::unique_fd& socket,
423                                                            int newUid, uid_t callerUid) {
424     ALOGD("XfrmController:%s, line=%d", __FUNCTION__, __LINE__);
425 
426     const int fd = socket.get();
427     struct stat info;
428     if (fstat(fd, &info)) {
429         return netdutils::statusFromErrno(errno, "Failed to stat socket file descriptor");
430     }
431     if (info.st_uid != callerUid) {
432         return netdutils::statusFromErrno(EPERM, "fchown disabled for non-owner calls");
433     }
434     if (S_ISSOCK(info.st_mode) == 0) {
435         return netdutils::statusFromErrno(EINVAL, "File descriptor was not a socket");
436     }
437 
438     int optval;
439     socklen_t optlen;
440     netdutils::Status status =
441         getSyscallInstance().getsockopt(Fd(socket), IPPROTO_UDP, UDP_ENCAP, &optval, &optlen);
442     if (status != netdutils::status::ok) {
443         return status;
444     }
445     if (optval != UDP_ENCAP_ESPINUDP && optval != UDP_ENCAP_ESPINUDP_NON_IKE) {
446         return netdutils::statusFromErrno(EINVAL, "Socket did not have UDP-encap sockopt set");
447     }
448     if (fchown(fd, newUid, -1)) {
449         return netdutils::statusFromErrno(errno, "Failed to fchown socket file descriptor");
450     }
451 
452     return netdutils::status::ok;
453 }
454 
ipSecAllocateSpi(int32_t transformId,const std::string & sourceAddress,const std::string & destinationAddress,int32_t inSpi,int32_t * outSpi)455 netdutils::Status XfrmController::ipSecAllocateSpi(int32_t transformId,
456                                                    const std::string& sourceAddress,
457                                                    const std::string& destinationAddress,
458                                                    int32_t inSpi, int32_t* outSpi) {
459     ALOGD("XfrmController:%s, line=%d", __FUNCTION__, __LINE__);
460     ALOGD("transformId=%d", transformId);
461     ALOGD("sourceAddress=%s", sourceAddress.c_str());
462     ALOGD("destinationAddress=%s", destinationAddress.c_str());
463     ALOGD("inSpi=%0.8x", inSpi);
464 
465     XfrmSaInfo saInfo{};
466     netdutils::Status ret =
467         fillXfrmId(sourceAddress, destinationAddress, INVALID_SPI, 0, 0, transformId, &saInfo);
468     if (!isOk(ret)) {
469         return ret;
470     }
471 
472     XfrmSocketImpl sock;
473     netdutils::Status socketStatus = sock.open();
474     if (!isOk(socketStatus)) {
475         ALOGD("Sock open failed for XFRM, line=%d", __LINE__);
476         return socketStatus;
477     }
478 
479     int minSpi = RAND_SPI_MIN, maxSpi = RAND_SPI_MAX;
480 
481     if (inSpi)
482         minSpi = maxSpi = inSpi;
483 
484     ret = allocateSpi(saInfo, minSpi, maxSpi, reinterpret_cast<uint32_t*>(outSpi), sock);
485     if (!isOk(ret)) {
486         // TODO: May want to return a new Status with a modified status string
487         ALOGD("Failed to Allocate an SPI, line=%d", __LINE__);
488         *outSpi = INVALID_SPI;
489     }
490 
491     return ret;
492 }
493 
ipSecAddSecurityAssociation(int32_t transformId,int32_t mode,const std::string & sourceAddress,const std::string & destinationAddress,int32_t underlyingNetId,int32_t spi,int32_t markValue,int32_t markMask,const std::string & authAlgo,const std::vector<uint8_t> & authKey,int32_t authTruncBits,const std::string & cryptAlgo,const std::vector<uint8_t> & cryptKey,int32_t cryptTruncBits,const std::string & aeadAlgo,const std::vector<uint8_t> & aeadKey,int32_t aeadIcvBits,int32_t encapType,int32_t encapLocalPort,int32_t encapRemotePort)494 netdutils::Status XfrmController::ipSecAddSecurityAssociation(
495     int32_t transformId, int32_t mode, const std::string& sourceAddress,
496     const std::string& destinationAddress, int32_t underlyingNetId, int32_t spi, int32_t markValue,
497     int32_t markMask, const std::string& authAlgo, const std::vector<uint8_t>& authKey,
498     int32_t authTruncBits, const std::string& cryptAlgo, const std::vector<uint8_t>& cryptKey,
499     int32_t cryptTruncBits, const std::string& aeadAlgo, const std::vector<uint8_t>& aeadKey,
500     int32_t aeadIcvBits, int32_t encapType, int32_t encapLocalPort, int32_t encapRemotePort) {
501     ALOGD("XfrmController::%s, line=%d", __FUNCTION__, __LINE__);
502     ALOGD("transformId=%d", transformId);
503     ALOGD("mode=%d", mode);
504     ALOGD("sourceAddress=%s", sourceAddress.c_str());
505     ALOGD("destinationAddress=%s", destinationAddress.c_str());
506     ALOGD("underlyingNetworkId=%d", underlyingNetId);
507     ALOGD("spi=%0.8x", spi);
508     ALOGD("markValue=%x", markValue);
509     ALOGD("markMask=%x", markMask);
510     ALOGD("authAlgo=%s", authAlgo.c_str());
511     ALOGD("authTruncBits=%d", authTruncBits);
512     ALOGD("cryptAlgo=%s", cryptAlgo.c_str());
513     ALOGD("cryptTruncBits=%d,", cryptTruncBits);
514     ALOGD("aeadAlgo=%s", aeadAlgo.c_str());
515     ALOGD("aeadIcvBits=%d,", aeadIcvBits);
516     ALOGD("encapType=%d", encapType);
517     ALOGD("encapLocalPort=%d", encapLocalPort);
518     ALOGD("encapRemotePort=%d", encapRemotePort);
519 
520     XfrmSaInfo saInfo{};
521     netdutils::Status ret = fillXfrmId(sourceAddress, destinationAddress, spi, markValue, markMask,
522                                        transformId, &saInfo);
523     if (!isOk(ret)) {
524         return ret;
525     }
526 
527     saInfo.auth = XfrmAlgo{
528         .name = authAlgo, .key = authKey, .truncLenBits = static_cast<uint16_t>(authTruncBits)};
529 
530     saInfo.crypt = XfrmAlgo{
531         .name = cryptAlgo, .key = cryptKey, .truncLenBits = static_cast<uint16_t>(cryptTruncBits)};
532 
533     saInfo.aead = XfrmAlgo{
534         .name = aeadAlgo, .key = aeadKey, .truncLenBits = static_cast<uint16_t>(aeadIcvBits)};
535 
536     switch (static_cast<XfrmMode>(mode)) {
537         case XfrmMode::TRANSPORT:
538         case XfrmMode::TUNNEL:
539             saInfo.mode = static_cast<XfrmMode>(mode);
540             break;
541         default:
542             return netdutils::statusFromErrno(EINVAL, "Invalid xfrm mode");
543     }
544 
545     XfrmSocketImpl sock;
546     netdutils::Status socketStatus = sock.open();
547     if (!isOk(socketStatus)) {
548         ALOGD("Sock open failed for XFRM, line=%d", __LINE__);
549         return socketStatus;
550     }
551 
552     switch (static_cast<XfrmEncapType>(encapType)) {
553         case XfrmEncapType::ESPINUDP:
554         case XfrmEncapType::ESPINUDP_NON_IKE:
555             if (saInfo.addrFamily != AF_INET) {
556                 return netdutils::statusFromErrno(EAFNOSUPPORT, "IPv6 encap not supported");
557             }
558             // The ports are not used on input SAs, so this is OK to be wrong when
559             // direction is ultimately input.
560             saInfo.encap.srcPort = encapLocalPort;
561             saInfo.encap.dstPort = encapRemotePort;
562         // fall through
563         case XfrmEncapType::NONE:
564             saInfo.encap.type = static_cast<XfrmEncapType>(encapType);
565             break;
566         default:
567             return netdutils::statusFromErrno(EINVAL, "Invalid encap type");
568     }
569 
570     saInfo.netId = underlyingNetId;
571 
572     ret = updateSecurityAssociation(saInfo, sock);
573     if (!isOk(ret)) {
574         ALOGD("Failed updating a Security Association, line=%d", __LINE__);
575     }
576 
577     return ret;
578 }
579 
ipSecDeleteSecurityAssociation(int32_t transformId,const std::string & sourceAddress,const std::string & destinationAddress,int32_t spi,int32_t markValue,int32_t markMask)580 netdutils::Status XfrmController::ipSecDeleteSecurityAssociation(
581     int32_t transformId, const std::string& sourceAddress, const std::string& destinationAddress,
582     int32_t spi, int32_t markValue, int32_t markMask) {
583     ALOGD("XfrmController:%s, line=%d", __FUNCTION__, __LINE__);
584     ALOGD("transformId=%d", transformId);
585     ALOGD("sourceAddress=%s", sourceAddress.c_str());
586     ALOGD("destinationAddress=%s", destinationAddress.c_str());
587     ALOGD("spi=%0.8x", spi);
588     ALOGD("markValue=%x", markValue);
589     ALOGD("markMask=%x", markMask);
590 
591     XfrmId saId{};
592     netdutils::Status ret =
593         fillXfrmId(sourceAddress, destinationAddress, spi, markValue, markMask, transformId, &saId);
594     if (!isOk(ret)) {
595         return ret;
596     }
597 
598     XfrmSocketImpl sock;
599     netdutils::Status socketStatus = sock.open();
600     if (!isOk(socketStatus)) {
601         ALOGD("Sock open failed for XFRM, line=%d", __LINE__);
602         return socketStatus;
603     }
604 
605     ret = deleteSecurityAssociation(saId, sock);
606     if (!isOk(ret)) {
607         ALOGD("Failed to delete Security Association, line=%d", __LINE__);
608     }
609 
610     return ret;
611 }
612 
fillXfrmId(const std::string & sourceAddress,const std::string & destinationAddress,int32_t spi,int32_t markValue,int32_t markMask,int32_t transformId,XfrmId * xfrmId)613 netdutils::Status XfrmController::fillXfrmId(const std::string& sourceAddress,
614                                              const std::string& destinationAddress, int32_t spi,
615                                              int32_t markValue, int32_t markMask,
616                                              int32_t transformId, XfrmId* xfrmId) {
617     // Fill the straightforward fields first
618     xfrmId->transformId = transformId;
619     xfrmId->spi = htonl(spi);
620     xfrmId->mark.v = markValue;
621     xfrmId->mark.m = markMask;
622 
623     // Use the addresses to determine the address family and do validation
624     xfrm_address_t sourceXfrmAddr{}, destXfrmAddr{};
625     StatusOr<int> sourceFamily, destFamily;
626     sourceFamily = convertToXfrmAddr(sourceAddress, &sourceXfrmAddr);
627     destFamily = convertToXfrmAddr(destinationAddress, &destXfrmAddr);
628     if (!isOk(sourceFamily) || !isOk(destFamily)) {
629         return netdutils::statusFromErrno(EINVAL, "Invalid address " + sourceAddress + "/" +
630                                                       destinationAddress);
631     }
632 
633     if (destFamily.value() == AF_UNSPEC ||
634         (sourceFamily.value() != AF_UNSPEC && sourceFamily.value() != destFamily.value())) {
635         ALOGD("Invalid or Mismatched Address Families, %d != %d, line=%d", sourceFamily.value(),
636               destFamily.value(), __LINE__);
637         return netdutils::statusFromErrno(EINVAL, "Invalid or mismatched address families");
638     }
639 
640     xfrmId->addrFamily = destFamily.value();
641 
642     xfrmId->dstAddr = destXfrmAddr;
643     xfrmId->srcAddr = sourceXfrmAddr;
644     return netdutils::status::ok;
645 }
646 
ipSecApplyTransportModeTransform(const android::base::unique_fd & socket,int32_t transformId,int32_t direction,const std::string & sourceAddress,const std::string & destinationAddress,int32_t spi)647 netdutils::Status XfrmController::ipSecApplyTransportModeTransform(
648     const android::base::unique_fd& socket, int32_t transformId, int32_t direction,
649     const std::string& sourceAddress, const std::string& destinationAddress, int32_t spi) {
650     ALOGD("XfrmController::%s, line=%d", __FUNCTION__, __LINE__);
651     ALOGD("transformId=%d", transformId);
652     ALOGD("direction=%d", direction);
653     ALOGD("sourceAddress=%s", sourceAddress.c_str());
654     ALOGD("destinationAddress=%s", destinationAddress.c_str());
655     ALOGD("spi=%0.8x", spi);
656 
657     StatusOr<sockaddr_storage> ret = getSyscallInstance().getsockname<sockaddr_storage>(Fd(socket));
658     if (!isOk(ret)) {
659         ALOGE("Failed to get socket info in %s", __FUNCTION__);
660         return ret;
661     }
662     struct sockaddr_storage saddr = ret.value();
663 
664     XfrmSaInfo saInfo{};
665     netdutils::Status status =
666         fillXfrmId(sourceAddress, destinationAddress, spi, 0, 0, transformId, &saInfo);
667     if (!isOk(status)) {
668         ALOGE("Couldn't build SA ID %s", __FUNCTION__);
669         return status;
670     }
671 
672     if (saddr.ss_family == AF_INET && saInfo.addrFamily != AF_INET) {
673         ALOGE("IPV4 socket address family(%d) should match IPV4 Transform "
674               "address family(%d)!",
675               saddr.ss_family, saInfo.addrFamily);
676         return netdutils::statusFromErrno(EINVAL, "Mismatched address family");
677     }
678 
679     struct {
680         xfrm_userpolicy_info info;
681         xfrm_user_tmpl tmpl;
682     } policy{};
683 
684     fillTransportModeUserSpInfo(saInfo, static_cast<XfrmDirection>(direction), &policy.info);
685     fillUserTemplate(saInfo, &policy.tmpl);
686 
687     LOG_HEX("XfrmUserPolicy", reinterpret_cast<char*>(&policy), sizeof(policy));
688 
689     int sockOpt, sockLayer;
690     switch (saddr.ss_family) {
691         case AF_INET:
692             sockOpt = IP_XFRM_POLICY;
693             sockLayer = SOL_IP;
694             break;
695         case AF_INET6:
696             sockOpt = IPV6_XFRM_POLICY;
697             sockLayer = SOL_IPV6;
698             break;
699         default:
700             return netdutils::statusFromErrno(EAFNOSUPPORT, "Invalid address family");
701     }
702 
703     status = getSyscallInstance().setsockopt(Fd(socket), sockLayer, sockOpt, policy);
704     if (!isOk(status)) {
705         ALOGE("Error setting socket option for XFRM! (%s)", toString(status).c_str());
706     }
707 
708     return status;
709 }
710 
711 netdutils::Status
ipSecRemoveTransportModeTransform(const android::base::unique_fd & socket)712 XfrmController::ipSecRemoveTransportModeTransform(const android::base::unique_fd& socket) {
713     ALOGD("XfrmController::%s, line=%d", __FUNCTION__, __LINE__);
714 
715     StatusOr<sockaddr_storage> ret = getSyscallInstance().getsockname<sockaddr_storage>(Fd(socket));
716     if (!isOk(ret)) {
717         ALOGE("Failed to get socket info in %s! (%s)", __FUNCTION__, toString(ret).c_str());
718         return ret;
719     }
720 
721     int sockOpt, sockLayer;
722     switch (ret.value().ss_family) {
723         case AF_INET:
724             sockOpt = IP_XFRM_POLICY;
725             sockLayer = SOL_IP;
726             break;
727         case AF_INET6:
728             sockOpt = IPV6_XFRM_POLICY;
729             sockLayer = SOL_IPV6;
730             break;
731         default:
732             return netdutils::statusFromErrno(EAFNOSUPPORT, "Invalid address family");
733     }
734 
735     // Kernel will delete the security policy on this socket for both direction
736     // if optval is set to NULL and optlen is set to 0.
737     netdutils::Status status =
738         getSyscallInstance().setsockopt(Fd(socket), sockLayer, sockOpt, NULL, 0);
739     if (!isOk(status)) {
740         ALOGE("Error removing socket option for XFRM! (%s)", toString(status).c_str());
741     }
742 
743     return status;
744 }
745 
ipSecAddSecurityPolicy(int32_t transformId,int32_t direction,const std::string & localAddress,const std::string & remoteAddress,int32_t spi,int32_t markValue,int32_t markMask)746 netdutils::Status XfrmController::ipSecAddSecurityPolicy(int32_t transformId, int32_t direction,
747                                                          const std::string& localAddress,
748                                                          const std::string& remoteAddress,
749                                                          int32_t spi, int32_t markValue,
750                                                          int32_t markMask) {
751     return processSecurityPolicy(transformId, direction, localAddress, remoteAddress, spi,
752                                  markValue, markMask, XFRM_MSG_NEWPOLICY);
753 }
754 
ipSecUpdateSecurityPolicy(int32_t transformId,int32_t direction,const std::string & localAddress,const std::string & remoteAddress,int32_t spi,int32_t markValue,int32_t markMask)755 netdutils::Status XfrmController::ipSecUpdateSecurityPolicy(int32_t transformId, int32_t direction,
756                                                             const std::string& localAddress,
757                                                             const std::string& remoteAddress,
758                                                             int32_t spi, int32_t markValue,
759                                                             int32_t markMask) {
760     return processSecurityPolicy(transformId, direction, localAddress, remoteAddress, spi,
761                                  markValue, markMask, XFRM_MSG_UPDPOLICY);
762 }
763 
ipSecDeleteSecurityPolicy(int32_t transformId,int32_t direction,const std::string & localAddress,const std::string & remoteAddress,int32_t markValue,int32_t markMask)764 netdutils::Status XfrmController::ipSecDeleteSecurityPolicy(int32_t transformId, int32_t direction,
765                                                             const std::string& localAddress,
766                                                             const std::string& remoteAddress,
767                                                             int32_t markValue, int32_t markMask) {
768     return processSecurityPolicy(transformId, direction, localAddress, remoteAddress, 0, markValue,
769                                  markMask, XFRM_MSG_DELPOLICY);
770 }
771 
processSecurityPolicy(int32_t transformId,int32_t direction,const std::string & localAddress,const std::string & remoteAddress,int32_t spi,int32_t markValue,int32_t markMask,int32_t msgType)772 netdutils::Status XfrmController::processSecurityPolicy(int32_t transformId, int32_t direction,
773                                                         const std::string& localAddress,
774                                                         const std::string& remoteAddress,
775                                                         int32_t spi, int32_t markValue,
776                                                         int32_t markMask, int32_t msgType) {
777     ALOGD("XfrmController::%s, line=%d", __FUNCTION__, __LINE__);
778     ALOGD("transformId=%d", transformId);
779     ALOGD("direction=%d", direction);
780     ALOGD("localAddress=%s", localAddress.c_str());
781     ALOGD("remoteAddress=%s", remoteAddress.c_str());
782     ALOGD("spi=%0.8x", spi);
783     ALOGD("markValue=%d", markValue);
784     ALOGD("markMask=%d", markMask);
785     ALOGD("msgType=%d", msgType);
786 
787     XfrmSaInfo saInfo{};
788     saInfo.mode = XfrmMode::TUNNEL;
789 
790     XfrmSocketImpl sock;
791     RETURN_IF_NOT_OK(sock.open());
792 
793     RETURN_IF_NOT_OK(
794         fillXfrmId(localAddress, remoteAddress, spi, markValue, markMask, transformId, &saInfo));
795 
796     if (msgType == XFRM_MSG_DELPOLICY) {
797         return deleteTunnelModeSecurityPolicy(saInfo, sock, static_cast<XfrmDirection>(direction));
798     } else {
799         return updateTunnelModeSecurityPolicy(saInfo, sock, static_cast<XfrmDirection>(direction),
800                                               msgType);
801     }
802 }
803 
fillXfrmSelector(const XfrmSaInfo & record,xfrm_selector * selector)804 void XfrmController::fillXfrmSelector(const XfrmSaInfo& record, xfrm_selector* selector) {
805     selector->family = record.addrFamily;
806     selector->proto = AF_UNSPEC; // TODO: do we need to match the protocol? it's
807                                  // possible via the socket
808 }
809 
updateSecurityAssociation(const XfrmSaInfo & record,const XfrmSocket & sock)810 netdutils::Status XfrmController::updateSecurityAssociation(const XfrmSaInfo& record,
811                                                             const XfrmSocket& sock) {
812     xfrm_usersa_info usersa{};
813     nlattr_algo_crypt crypt{};
814     nlattr_algo_auth auth{};
815     nlattr_algo_aead aead{};
816     nlattr_xfrm_mark xfrmmark{};
817     nlattr_xfrm_output_mark xfrmoutputmark{};
818     nlattr_encap_tmpl encap{};
819 
820     enum {
821         NLMSG_HDR,
822         USERSA,
823         USERSA_PAD,
824         CRYPT,
825         CRYPT_PAD,
826         AUTH,
827         AUTH_PAD,
828         AEAD,
829         AEAD_PAD,
830         MARK,
831         MARK_PAD,
832         OUTPUT_MARK,
833         OUTPUT_MARK_PAD,
834         ENCAP,
835         ENCAP_PAD,
836     };
837 
838     std::vector<iovec> iov = {
839         {NULL, 0},            // reserved for the eventual addition of a NLMSG_HDR
840         {&usersa, 0},         // main usersa_info struct
841         {kPadBytes, 0},       // up to NLMSG_ALIGNTO pad bytes of padding
842         {&crypt, 0},          // adjust size if crypt algo is present
843         {kPadBytes, 0},       // up to NLATTR_ALIGNTO pad bytes
844         {&auth, 0},           // adjust size if auth algo is present
845         {kPadBytes, 0},       // up to NLATTR_ALIGNTO pad bytes
846         {&aead, 0},           // adjust size if aead algo is present
847         {kPadBytes, 0},       // up to NLATTR_ALIGNTO pad bytes
848         {&xfrmmark, 0},       // adjust size if xfrm mark is present
849         {kPadBytes, 0},       // up to NLATTR_ALIGNTO pad bytes
850         {&xfrmoutputmark, 0}, // adjust size if xfrm output mark is present
851         {kPadBytes, 0},       // up to NLATTR_ALIGNTO pad bytes
852         {&encap, 0},          // adjust size if encapsulating
853         {kPadBytes, 0},       // up to NLATTR_ALIGNTO pad bytes
854     };
855 
856     if (!record.aead.name.empty() && (!record.auth.name.empty() || !record.crypt.name.empty())) {
857         return netdutils::statusFromErrno(EINVAL, "Invalid xfrm algo selection; AEAD is mutually "
858                                                   "exclusive with both Authentication and "
859                                                   "Encryption");
860     }
861 
862     if (record.aead.key.size() > MAX_KEY_LENGTH || record.auth.key.size() > MAX_KEY_LENGTH ||
863         record.crypt.key.size() > MAX_KEY_LENGTH) {
864         return netdutils::statusFromErrno(EINVAL, "Key length invalid; exceeds MAX_KEY_LENGTH");
865     }
866 
867     int len;
868     len = iov[USERSA].iov_len = fillUserSaInfo(record, &usersa);
869     iov[USERSA_PAD].iov_len = NLMSG_ALIGN(len) - len;
870 
871     len = iov[CRYPT].iov_len = fillNlAttrXfrmAlgoEnc(record.crypt, &crypt);
872     iov[CRYPT_PAD].iov_len = NLA_ALIGN(len) - len;
873 
874     len = iov[AUTH].iov_len = fillNlAttrXfrmAlgoAuth(record.auth, &auth);
875     iov[AUTH_PAD].iov_len = NLA_ALIGN(len) - len;
876 
877     len = iov[AEAD].iov_len = fillNlAttrXfrmAlgoAead(record.aead, &aead);
878     iov[AEAD_PAD].iov_len = NLA_ALIGN(len) - len;
879 
880     len = iov[MARK].iov_len = fillNlAttrXfrmMark(record, &xfrmmark);
881     iov[MARK_PAD].iov_len = NLA_ALIGN(len) - len;
882 
883     len = iov[OUTPUT_MARK].iov_len = fillNlAttrXfrmOutputMark(record.netId, &xfrmoutputmark);
884     iov[OUTPUT_MARK_PAD].iov_len = NLA_ALIGN(len) - len;
885 
886     len = iov[ENCAP].iov_len = fillNlAttrXfrmEncapTmpl(record, &encap);
887     iov[ENCAP_PAD].iov_len = NLA_ALIGN(len) - len;
888 
889     return sock.sendMessage(XFRM_MSG_UPDSA, NETLINK_REQUEST_FLAGS, 0, &iov);
890 }
891 
fillNlAttrXfrmAlgoEnc(const XfrmAlgo & inAlgo,nlattr_algo_crypt * algo)892 int XfrmController::fillNlAttrXfrmAlgoEnc(const XfrmAlgo& inAlgo, nlattr_algo_crypt* algo) {
893     if (inAlgo.name.empty()) { // Do not fill anything if algorithm not provided
894         return 0;
895     }
896 
897     int len = NLA_HDRLEN + sizeof(xfrm_algo);
898     // Kernel always changes last char to null terminator; no safety checks needed.
899     strncpy(algo->crypt.alg_name, inAlgo.name.c_str(), sizeof(algo->crypt.alg_name));
900     algo->crypt.alg_key_len = inAlgo.key.size() * 8; // bits
901     memcpy(algo->key, &inAlgo.key[0], inAlgo.key.size());
902     len += inAlgo.key.size();
903     fillXfrmNlaHdr(&algo->hdr, XFRMA_ALG_CRYPT, len);
904     return len;
905 }
906 
fillNlAttrXfrmAlgoAuth(const XfrmAlgo & inAlgo,nlattr_algo_auth * algo)907 int XfrmController::fillNlAttrXfrmAlgoAuth(const XfrmAlgo& inAlgo, nlattr_algo_auth* algo) {
908     if (inAlgo.name.empty()) { // Do not fill anything if algorithm not provided
909         return 0;
910     }
911 
912     int len = NLA_HDRLEN + sizeof(xfrm_algo_auth);
913     // Kernel always changes last char to null terminator; no safety checks needed.
914     strncpy(algo->auth.alg_name, inAlgo.name.c_str(), sizeof(algo->auth.alg_name));
915     algo->auth.alg_key_len = inAlgo.key.size() * 8; // bits
916 
917     // This is the extra field for ALG_AUTH_TRUNC
918     algo->auth.alg_trunc_len = inAlgo.truncLenBits;
919 
920     memcpy(algo->key, &inAlgo.key[0], inAlgo.key.size());
921     len += inAlgo.key.size();
922 
923     fillXfrmNlaHdr(&algo->hdr, XFRMA_ALG_AUTH_TRUNC, len);
924     return len;
925 }
926 
fillNlAttrXfrmAlgoAead(const XfrmAlgo & inAlgo,nlattr_algo_aead * algo)927 int XfrmController::fillNlAttrXfrmAlgoAead(const XfrmAlgo& inAlgo, nlattr_algo_aead* algo) {
928     if (inAlgo.name.empty()) { // Do not fill anything if algorithm not provided
929         return 0;
930     }
931 
932     int len = NLA_HDRLEN + sizeof(xfrm_algo_aead);
933     // Kernel always changes last char to null terminator; no safety checks needed.
934     strncpy(algo->aead.alg_name, inAlgo.name.c_str(), sizeof(algo->aead.alg_name));
935     algo->aead.alg_key_len = inAlgo.key.size() * 8; // bits
936 
937     // This is the extra field for ALG_AEAD. ICV length is the same as truncation length
938     // for any AEAD algorithm.
939     algo->aead.alg_icv_len = inAlgo.truncLenBits;
940 
941     memcpy(algo->key, &inAlgo.key[0], inAlgo.key.size());
942     len += inAlgo.key.size();
943 
944     fillXfrmNlaHdr(&algo->hdr, XFRMA_ALG_AEAD, len);
945     return len;
946 }
947 
fillNlAttrXfrmEncapTmpl(const XfrmSaInfo & record,nlattr_encap_tmpl * tmpl)948 int XfrmController::fillNlAttrXfrmEncapTmpl(const XfrmSaInfo& record, nlattr_encap_tmpl* tmpl) {
949     if (record.encap.type == XfrmEncapType::NONE) {
950         return 0;
951     }
952 
953     int len = NLA_HDRLEN + sizeof(xfrm_encap_tmpl);
954     tmpl->tmpl.encap_type = static_cast<uint16_t>(record.encap.type);
955     tmpl->tmpl.encap_sport = htons(record.encap.srcPort);
956     tmpl->tmpl.encap_dport = htons(record.encap.dstPort);
957     fillXfrmNlaHdr(&tmpl->hdr, XFRMA_ENCAP, len);
958     return len;
959 }
960 
fillUserSaInfo(const XfrmSaInfo & record,xfrm_usersa_info * usersa)961 int XfrmController::fillUserSaInfo(const XfrmSaInfo& record, xfrm_usersa_info* usersa) {
962     fillXfrmSelector(record, &usersa->sel);
963 
964     usersa->id.proto = IPPROTO_ESP;
965     usersa->id.spi = record.spi;
966     usersa->id.daddr = record.dstAddr;
967 
968     usersa->saddr = record.srcAddr;
969 
970     fillXfrmLifetimeDefaults(&usersa->lft);
971     fillXfrmCurLifetimeDefaults(&usersa->curlft);
972     memset(&usersa->stats, 0, sizeof(usersa->stats)); // leave stats zeroed out
973     usersa->reqid = record.transformId;
974     usersa->family = record.addrFamily;
975     usersa->mode = static_cast<uint8_t>(record.mode);
976     usersa->replay_window = REPLAY_WINDOW_SIZE;
977 
978     if (record.mode == XfrmMode::TRANSPORT) {
979         usersa->flags = 0; // TODO: should we actually set flags, XFRM_SA_XFLAG_DONT_ENCAP_DSCP?
980     } else {
981         usersa->flags = XFRM_STATE_AF_UNSPEC;
982     }
983 
984     return sizeof(*usersa);
985 }
986 
fillUserSaId(const XfrmId & record,xfrm_usersa_id * said)987 int XfrmController::fillUserSaId(const XfrmId& record, xfrm_usersa_id* said) {
988     said->daddr = record.dstAddr;
989     said->spi = record.spi;
990     said->family = record.addrFamily;
991     said->proto = IPPROTO_ESP;
992 
993     return sizeof(*said);
994 }
995 
deleteSecurityAssociation(const XfrmId & record,const XfrmSocket & sock)996 netdutils::Status XfrmController::deleteSecurityAssociation(const XfrmId& record,
997                                                             const XfrmSocket& sock) {
998     xfrm_usersa_id said{};
999     nlattr_xfrm_mark xfrmmark{};
1000 
1001     enum { NLMSG_HDR, USERSAID, USERSAID_PAD, MARK, MARK_PAD };
1002 
1003     std::vector<iovec> iov = {
1004         {NULL, 0},      // reserved for the eventual addition of a NLMSG_HDR
1005         {&said, 0},     // main usersa_info struct
1006         {kPadBytes, 0}, // up to NLMSG_ALIGNTO pad bytes of padding
1007         {&xfrmmark, 0}, // adjust size if xfrm mark is present
1008         {kPadBytes, 0}, // up to NLATTR_ALIGNTO pad bytes
1009     };
1010 
1011     int len;
1012     len = iov[USERSAID].iov_len = fillUserSaId(record, &said);
1013     iov[USERSAID_PAD].iov_len = NLMSG_ALIGN(len) - len;
1014 
1015     len = iov[MARK].iov_len = fillNlAttrXfrmMark(record, &xfrmmark);
1016     iov[MARK_PAD].iov_len = NLA_ALIGN(len) - len;
1017 
1018     return sock.sendMessage(XFRM_MSG_DELSA, NETLINK_REQUEST_FLAGS, 0, &iov);
1019 }
1020 
allocateSpi(const XfrmSaInfo & record,uint32_t minSpi,uint32_t maxSpi,uint32_t * outSpi,const XfrmSocket & sock)1021 netdutils::Status XfrmController::allocateSpi(const XfrmSaInfo& record, uint32_t minSpi,
1022                                               uint32_t maxSpi, uint32_t* outSpi,
1023                                               const XfrmSocket& sock) {
1024     xfrm_userspi_info spiInfo{};
1025 
1026     enum { NLMSG_HDR, USERSAID, USERSAID_PAD };
1027 
1028     std::vector<iovec> iov = {
1029         {NULL, 0},      // reserved for the eventual addition of a NLMSG_HDR
1030         {&spiInfo, 0},  // main userspi_info struct
1031         {kPadBytes, 0}, // up to NLMSG_ALIGNTO pad bytes of padding
1032     };
1033 
1034     int len;
1035     if (fillUserSaInfo(record, &spiInfo.info) == 0) {
1036         ALOGE("Failed to fill transport SA Info");
1037     }
1038 
1039     len = iov[USERSAID].iov_len = sizeof(spiInfo);
1040     iov[USERSAID_PAD].iov_len = NLMSG_ALIGN(len) - len;
1041 
1042     RandomSpi spiGen = RandomSpi(minSpi, maxSpi);
1043     int spi;
1044     netdutils::Status ret;
1045     while ((spi = spiGen.next()) != INVALID_SPI) {
1046         spiInfo.min = spi;
1047         spiInfo.max = spi;
1048         ret = sock.sendMessage(XFRM_MSG_ALLOCSPI, NETLINK_REQUEST_FLAGS, 0, &iov);
1049 
1050         /* If the SPI is in use, we'll get ENOENT */
1051         if (netdutils::equalToErrno(ret, ENOENT))
1052             continue;
1053 
1054         if (isOk(ret)) {
1055             *outSpi = spi;
1056             ALOGD("Allocated an SPI: %x", *outSpi);
1057         } else {
1058             *outSpi = INVALID_SPI;
1059             ALOGE("SPI Allocation Failed with error %d", ret.code());
1060         }
1061 
1062         return ret;
1063     }
1064 
1065     // Should always be -ENOENT if we get here
1066     return ret;
1067 }
1068 
updateTunnelModeSecurityPolicy(const XfrmSaInfo & record,const XfrmSocket & sock,XfrmDirection direction,uint16_t msgType)1069 netdutils::Status XfrmController::updateTunnelModeSecurityPolicy(const XfrmSaInfo& record,
1070                                                                  const XfrmSocket& sock,
1071                                                                  XfrmDirection direction,
1072                                                                  uint16_t msgType) {
1073     xfrm_userpolicy_info userpolicy{};
1074     nlattr_user_tmpl usertmpl{};
1075     nlattr_xfrm_mark xfrmmark{};
1076 
1077     enum {
1078         NLMSG_HDR,
1079         USERPOLICY,
1080         USERPOLICY_PAD,
1081         USERTMPL,
1082         USERTMPL_PAD,
1083         MARK,
1084         MARK_PAD,
1085     };
1086 
1087     std::vector<iovec> iov = {
1088         {NULL, 0},        // reserved for the eventual addition of a NLMSG_HDR
1089         {&userpolicy, 0}, // main xfrm_userpolicy_info struct
1090         {kPadBytes, 0},   // up to NLMSG_ALIGNTO pad bytes of padding
1091         {&usertmpl, 0},   // adjust size if xfrm_user_tmpl struct is present
1092         {kPadBytes, 0},   // up to NLATTR_ALIGNTO pad bytes
1093         {&xfrmmark, 0},   // adjust size if xfrm mark is present
1094         {kPadBytes, 0},   // up to NLATTR_ALIGNTO pad bytes
1095     };
1096 
1097     int len;
1098     len = iov[USERPOLICY].iov_len = fillTransportModeUserSpInfo(record, direction, &userpolicy);
1099     iov[USERPOLICY_PAD].iov_len = NLMSG_ALIGN(len) - len;
1100 
1101     len = iov[USERTMPL].iov_len = fillNlAttrUserTemplate(record, &usertmpl);
1102     iov[USERTMPL_PAD].iov_len = NLA_ALIGN(len) - len;
1103 
1104     len = iov[MARK].iov_len = fillNlAttrXfrmMark(record, &xfrmmark);
1105     iov[MARK_PAD].iov_len = NLA_ALIGN(len) - len;
1106 
1107     return sock.sendMessage(msgType, NETLINK_REQUEST_FLAGS, 0, &iov);
1108 }
1109 
deleteTunnelModeSecurityPolicy(const XfrmSaInfo & record,const XfrmSocket & sock,XfrmDirection direction)1110 netdutils::Status XfrmController::deleteTunnelModeSecurityPolicy(const XfrmSaInfo& record,
1111                                                                  const XfrmSocket& sock,
1112                                                                  XfrmDirection direction) {
1113     xfrm_userpolicy_id policyid{};
1114     nlattr_xfrm_mark xfrmmark{};
1115 
1116     enum {
1117         NLMSG_HDR,
1118         USERPOLICYID,
1119         USERPOLICYID_PAD,
1120         MARK,
1121         MARK_PAD,
1122     };
1123 
1124     std::vector<iovec> iov = {
1125         {NULL, 0},      // reserved for the eventual addition of a NLMSG_HDR
1126         {&policyid, 0}, // main xfrm_userpolicy_id struct
1127         {kPadBytes, 0}, // up to NLMSG_ALIGNTO pad bytes of padding
1128         {&xfrmmark, 0}, // adjust size if xfrm mark is present
1129         {kPadBytes, 0}, // up to NLATTR_ALIGNTO pad bytes
1130     };
1131 
1132     int len = iov[USERPOLICYID].iov_len = fillUserPolicyId(record, direction, &policyid);
1133     iov[USERPOLICYID_PAD].iov_len = NLMSG_ALIGN(len) - len;
1134 
1135     len = iov[MARK].iov_len = fillNlAttrXfrmMark(record, &xfrmmark);
1136     iov[MARK_PAD].iov_len = NLA_ALIGN(len) - len;
1137 
1138     return sock.sendMessage(XFRM_MSG_DELPOLICY, NETLINK_REQUEST_FLAGS, 0, &iov);
1139 }
1140 
fillTransportModeUserSpInfo(const XfrmSaInfo & record,XfrmDirection direction,xfrm_userpolicy_info * usersp)1141 int XfrmController::fillTransportModeUserSpInfo(const XfrmSaInfo& record, XfrmDirection direction,
1142                                                 xfrm_userpolicy_info* usersp) {
1143     fillXfrmSelector(record, &usersp->sel);
1144     fillXfrmLifetimeDefaults(&usersp->lft);
1145     fillXfrmCurLifetimeDefaults(&usersp->curlft);
1146     /* if (index) index & 0x3 == dir -- must be true
1147      * xfrm_user.c:verify_newpolicy_info() */
1148     usersp->index = 0;
1149     usersp->dir = static_cast<uint8_t>(direction);
1150     usersp->action = XFRM_POLICY_ALLOW;
1151     usersp->flags = XFRM_POLICY_LOCALOK;
1152     usersp->share = XFRM_SHARE_UNIQUE;
1153     return sizeof(*usersp);
1154 }
1155 
fillUserTemplate(const XfrmSaInfo & record,xfrm_user_tmpl * tmpl)1156 int XfrmController::fillUserTemplate(const XfrmSaInfo& record, xfrm_user_tmpl* tmpl) {
1157     tmpl->id.daddr = record.dstAddr;
1158     tmpl->id.spi = record.spi;
1159     tmpl->id.proto = IPPROTO_ESP;
1160 
1161     tmpl->family = record.addrFamily;
1162     tmpl->saddr = record.srcAddr;
1163     tmpl->reqid = record.transformId;
1164     tmpl->mode = static_cast<uint8_t>(record.mode);
1165     tmpl->share = XFRM_SHARE_UNIQUE;
1166     tmpl->optional = 0; // if this is true, then a failed state lookup will be considered OK:
1167                         // http://lxr.free-electrons.com/source/net/xfrm/xfrm_policy.c#L1492
1168     tmpl->aalgos = ALGO_MASK_AUTH_ALL;  // TODO: if there's a bitmask somewhere of
1169                                         // algos, we should find it and apply it.
1170                                         // I can't find one.
1171     tmpl->ealgos = ALGO_MASK_CRYPT_ALL; // TODO: if there's a bitmask somewhere...
1172     return sizeof(xfrm_user_tmpl*);
1173 }
1174 
fillNlAttrUserTemplate(const XfrmSaInfo & record,nlattr_user_tmpl * tmpl)1175 int XfrmController::fillNlAttrUserTemplate(const XfrmSaInfo& record, nlattr_user_tmpl* tmpl) {
1176     fillUserTemplate(record, &tmpl->tmpl);
1177 
1178     int len = NLA_HDRLEN + sizeof(xfrm_user_tmpl);
1179     fillXfrmNlaHdr(&tmpl->hdr, XFRMA_TMPL, len);
1180     return len;
1181 }
1182 
fillNlAttrXfrmMark(const XfrmId & record,nlattr_xfrm_mark * mark)1183 int XfrmController::fillNlAttrXfrmMark(const XfrmId& record, nlattr_xfrm_mark* mark) {
1184     mark->mark.v = record.mark.v; // set to 0 if it's not used
1185     mark->mark.m = record.mark.m; // set to 0 if it's not used
1186     int len = NLA_HDRLEN + sizeof(xfrm_mark);
1187     fillXfrmNlaHdr(&mark->hdr, XFRMA_MARK, len);
1188     return len;
1189 }
1190 
fillNlAttrXfrmOutputMark(const __u32 output_mark_value,nlattr_xfrm_output_mark * output_mark)1191 int XfrmController::fillNlAttrXfrmOutputMark(const __u32 output_mark_value,
1192                                              nlattr_xfrm_output_mark* output_mark) {
1193     // Do not set if we were not given an output mark
1194     if (output_mark_value == 0) {
1195         return 0;
1196     }
1197 
1198     output_mark->outputMark = output_mark_value;
1199     int len = NLA_HDRLEN + sizeof(__u32);
1200     fillXfrmNlaHdr(&output_mark->hdr, XFRMA_OUTPUT_MARK, len);
1201     return len;
1202 }
1203 
fillUserPolicyId(const XfrmSaInfo & record,XfrmDirection direction,xfrm_userpolicy_id * usersp)1204 int XfrmController::fillUserPolicyId(const XfrmSaInfo& record, XfrmDirection direction,
1205                                      xfrm_userpolicy_id* usersp) {
1206     // For DELPOLICY, when index is absent, selector is needed to match the policy
1207     fillXfrmSelector(record, &usersp->sel);
1208     usersp->dir = static_cast<uint8_t>(direction);
1209     return sizeof(*usersp);
1210 }
1211 
addVirtualTunnelInterface(const std::string & deviceName,const std::string & localAddress,const std::string & remoteAddress,int32_t ikey,int32_t okey,bool isUpdate)1212 int XfrmController::addVirtualTunnelInterface(const std::string& deviceName,
1213                                               const std::string& localAddress,
1214                                               const std::string& remoteAddress, int32_t ikey,
1215                                               int32_t okey, bool isUpdate) {
1216     ALOGD("XfrmController::%s, line=%d", __FUNCTION__, __LINE__);
1217     ALOGD("deviceName=%s", deviceName.c_str());
1218     ALOGD("localAddress=%s", localAddress.c_str());
1219     ALOGD("remoteAddress=%s", remoteAddress.c_str());
1220     ALOGD("ikey=%0.8x", ikey);
1221     ALOGD("okey=%0.8x", okey);
1222     ALOGD("isUpdate=%d", isUpdate);
1223 
1224     if (deviceName.empty() || localAddress.empty() || remoteAddress.empty()) {
1225         return EINVAL;
1226     }
1227 
1228     const char* INFO_KIND_VTI6 = "vti6";
1229     const char* INFO_KIND_VTI = "vti";
1230     uint8_t PADDING_BUFFER[] = {0, 0, 0, 0};
1231 
1232     // Find address family.
1233     uint8_t remAddr[sizeof(in6_addr)];
1234 
1235     StatusOr<uint16_t> statusOrRemoteFam = convertStringAddress(remoteAddress, remAddr);
1236     if (!isOk(statusOrRemoteFam)) {
1237         return statusOrRemoteFam.status().code();
1238     }
1239 
1240     uint8_t locAddr[sizeof(in6_addr)];
1241     StatusOr<uint16_t> statusOrLocalFam = convertStringAddress(localAddress, locAddr);
1242     if (!isOk(statusOrLocalFam)) {
1243         return statusOrLocalFam.status().code();
1244     }
1245 
1246     if (statusOrLocalFam.value() != statusOrRemoteFam.value()) {
1247         return EINVAL;
1248     }
1249 
1250     uint16_t family = statusOrLocalFam.value();
1251 
1252     ifinfomsg ifInfoMsg{};
1253 
1254     // Construct IFLA_IFNAME
1255     nlattr iflaIfName;
1256     char iflaIfNameStrValue[deviceName.length() + 1];
1257     size_t iflaIfNameLength =
1258         strlcpy(iflaIfNameStrValue, deviceName.c_str(), sizeof(iflaIfNameStrValue));
1259     size_t iflaIfNamePad = fillNlAttr(IFLA_IFNAME, iflaIfNameLength, &iflaIfName);
1260 
1261     // Construct IFLA_INFO_KIND
1262     // Constants "vti6" and "vti" enable the kernel to call different code paths,
1263     // (ip_tunnel.c, ip6_tunnel), based on the family.
1264     const std::string infoKindValue = (family == AF_INET6) ? INFO_KIND_VTI6 : INFO_KIND_VTI;
1265     nlattr iflaIfInfoKind;
1266     char infoKindValueStrValue[infoKindValue.length() + 1];
1267     size_t iflaIfInfoKindLength =
1268         strlcpy(infoKindValueStrValue, infoKindValue.c_str(), sizeof(infoKindValueStrValue));
1269     size_t iflaIfInfoKindPad = fillNlAttr(IFLA_INFO_KIND, iflaIfInfoKindLength, &iflaIfInfoKind);
1270 
1271     // Construct IFLA_VTI_LOCAL
1272     nlattr iflaVtiLocal;
1273     uint8_t binaryLocalAddress[sizeof(in6_addr)];
1274     size_t iflaVtiLocalPad =
1275         fillNlAttrIpAddress(IFLA_VTI_LOCAL, family, localAddress, &iflaVtiLocal,
1276                             netdutils::makeSlice(binaryLocalAddress));
1277 
1278     // Construct IFLA_VTI_REMOTE
1279     nlattr iflaVtiRemote;
1280     uint8_t binaryRemoteAddress[sizeof(in6_addr)];
1281     size_t iflaVtiRemotePad =
1282         fillNlAttrIpAddress(IFLA_VTI_REMOTE, family, remoteAddress, &iflaVtiRemote,
1283                             netdutils::makeSlice(binaryRemoteAddress));
1284 
1285     // Construct IFLA_VTI_OKEY
1286     nlattr iflaVtiIKey;
1287     uint32_t iKeyValue;
1288     size_t iflaVtiIKeyPad = fillNlAttrU32(IFLA_VTI_IKEY, ikey, &iflaVtiIKey, &iKeyValue);
1289 
1290     // Construct IFLA_VTI_IKEY
1291     nlattr iflaVtiOKey;
1292     uint32_t oKeyValue;
1293     size_t iflaVtiOKeyPad = fillNlAttrU32(IFLA_VTI_OKEY, okey, &iflaVtiOKey, &oKeyValue);
1294 
1295     int iflaInfoDataPayloadLength = iflaVtiLocal.nla_len + iflaVtiLocalPad + iflaVtiRemote.nla_len +
1296                                     iflaVtiRemotePad + iflaVtiIKey.nla_len + iflaVtiIKeyPad +
1297                                     iflaVtiOKey.nla_len + iflaVtiOKeyPad;
1298 
1299     // Construct IFLA_INFO_DATA
1300     nlattr iflaInfoData;
1301     size_t iflaInfoDataPad = fillNlAttr(IFLA_INFO_DATA, iflaInfoDataPayloadLength, &iflaInfoData);
1302 
1303     // Construct IFLA_LINKINFO
1304     nlattr iflaLinkInfo;
1305     size_t iflaLinkInfoPad = fillNlAttr(IFLA_LINKINFO,
1306                                         iflaInfoData.nla_len + iflaInfoDataPad +
1307                                             iflaIfInfoKind.nla_len + iflaIfInfoKindPad,
1308                                         &iflaLinkInfo);
1309 
1310     iovec iov[] = {
1311         {NULL, 0},
1312         {&ifInfoMsg, sizeof(ifInfoMsg)},
1313 
1314         {&iflaIfName, sizeof(iflaIfName)},
1315         {iflaIfNameStrValue, iflaIfNameLength},
1316         {&PADDING_BUFFER, iflaIfNamePad},
1317 
1318         {&iflaLinkInfo, sizeof(iflaLinkInfo)},
1319 
1320         {&iflaIfInfoKind, sizeof(iflaIfInfoKind)},
1321         {infoKindValueStrValue, iflaIfInfoKindLength},
1322         {&PADDING_BUFFER, iflaIfInfoKindPad},
1323 
1324         {&iflaInfoData, sizeof(iflaInfoData)},
1325 
1326         {&iflaVtiLocal, sizeof(iflaVtiLocal)},
1327         {&binaryLocalAddress, (family == AF_INET) ? sizeof(in_addr) : sizeof(in6_addr)},
1328         {&PADDING_BUFFER, iflaVtiLocalPad},
1329 
1330         {&iflaVtiRemote, sizeof(iflaVtiRemote)},
1331         {&binaryRemoteAddress, (family == AF_INET) ? sizeof(in_addr) : sizeof(in6_addr)},
1332         {&PADDING_BUFFER, iflaVtiRemotePad},
1333 
1334         {&iflaVtiIKey, sizeof(iflaVtiIKey)},
1335         {&iKeyValue, sizeof(iKeyValue)},
1336         {&PADDING_BUFFER, iflaVtiIKeyPad},
1337 
1338         {&iflaVtiOKey, sizeof(iflaVtiOKey)},
1339         {&oKeyValue, sizeof(oKeyValue)},
1340         {&PADDING_BUFFER, iflaVtiOKeyPad},
1341 
1342         {&PADDING_BUFFER, iflaInfoDataPad},
1343 
1344         {&PADDING_BUFFER, iflaLinkInfoPad},
1345     };
1346 
1347     uint16_t action = RTM_NEWLINK;
1348     uint16_t flags = NLM_F_REQUEST | NLM_F_ACK;
1349 
1350     if (!isUpdate) {
1351         flags |= NLM_F_EXCL | NLM_F_CREATE;
1352     }
1353 
1354     // sendNetlinkRequest returns -errno
1355     int ret = -1 * sendNetlinkRequest(action, flags, iov, ARRAY_SIZE(iov), nullptr);
1356     if (ret) {
1357         ALOGE("Error in %s virtual tunnel interface. Error Code: %d",
1358               isUpdate ? "updating" : "adding", ret);
1359     }
1360     return ret;
1361 }
1362 
removeVirtualTunnelInterface(const std::string & deviceName)1363 int XfrmController::removeVirtualTunnelInterface(const std::string& deviceName) {
1364     ALOGD("XfrmController::%s, line=%d", __FUNCTION__, __LINE__);
1365     ALOGD("deviceName=%s", deviceName.c_str());
1366 
1367     if (deviceName.empty()) {
1368         return EINVAL;
1369     }
1370 
1371     uint8_t PADDING_BUFFER[] = {0, 0, 0, 0};
1372 
1373     ifinfomsg ifInfoMsg{};
1374     nlattr iflaIfName;
1375     char iflaIfNameStrValue[deviceName.length() + 1];
1376     size_t iflaIfNameLength =
1377         strlcpy(iflaIfNameStrValue, deviceName.c_str(), sizeof(iflaIfNameStrValue));
1378     size_t iflaIfNamePad = fillNlAttr(IFLA_IFNAME, iflaIfNameLength, &iflaIfName);
1379 
1380     iovec iov[] = {
1381         {NULL, 0},
1382         {&ifInfoMsg, sizeof(ifInfoMsg)},
1383 
1384         {&iflaIfName, sizeof(iflaIfName)},
1385         {iflaIfNameStrValue, iflaIfNameLength},
1386         {&PADDING_BUFFER, iflaIfNamePad},
1387     };
1388 
1389     uint16_t action = RTM_DELLINK;
1390     uint16_t flags = NLM_F_REQUEST | NLM_F_ACK;
1391 
1392     // sendNetlinkRequest returns -errno
1393     int ret = -1 * sendNetlinkRequest(action, flags, iov, ARRAY_SIZE(iov), nullptr);
1394     if (ret) {
1395         ALOGE("Error in removing virtual tunnel interface %s. Error Code: %d", iflaIfNameStrValue,
1396               ret);
1397     }
1398     return ret;
1399 }
1400 
1401 } // namespace net
1402 } // namespace android
1403