1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #pragma once 18 19 #include <android-base/macros.h> 20 #include <android-base/unique_fd.h> 21 #include <libnl++/Buffer.h> 22 #include <libnl++/Message.h> 23 #include <libnl++/MessageFactory.h> 24 25 #include <linux/netlink.h> 26 #include <poll.h> 27 28 #include <optional> 29 #include <set> 30 #include <vector> 31 32 namespace android::nl { 33 34 /** 35 * A wrapper around AF_NETLINK sockets. 36 * 37 * This class is not thread safe to use a single instance between multiple threads, but it's fine to 38 * use multiple instances over multiple threads. 39 */ 40 class Socket { 41 public: 42 static constexpr size_t defaultReceiveSize = 8192; 43 44 /** 45 * Socket constructor. 46 * 47 * \param protocol the Netlink protocol to use. 48 * \param pid port id. Default value of 0 allows the kernel to assign us a unique pid. 49 * (NOTE: this is NOT the same as process id). 50 * \param groups Netlink multicast groups to listen to. This is a 32-bit bitfield, where each 51 * bit is a different group. Default value of 0 means no groups are selected. 52 * See man netlink.7. 53 * for more details. 54 */ 55 Socket(int protocol, unsigned pid = 0, uint32_t groups = 0); 56 57 /** 58 * Attempt to clear POLLERR by recv-ing. 59 * TODO(224850481): determine if this is necessary, or if the socket is locked up anyway. 60 */ 61 void clearPollErr(); 62 63 /** 64 * Send Netlink message with incremented sequence number to the Kernel. 65 * 66 * \param msg Message to send. Its sequence number will be updated. 67 * \return true, if succeeded. 68 */ 69 template <typename T, unsigned BUFSIZE> send(MessageFactory<T,BUFSIZE> & req)70 bool send(MessageFactory<T, BUFSIZE>& req) { 71 sockaddr_nl sa = {}; 72 sa.nl_family = AF_NETLINK; 73 sa.nl_pid = 0; // Kernel 74 return send(req, sa); 75 } 76 77 /** 78 * Send Netlink message with incremented sequence number. 79 * 80 * \param msg Message to send. Its sequence number will be updated. 81 * \param sa Destination address. 82 * \return true, if succeeded. 83 */ 84 template <typename T, unsigned BUFSIZE> send(MessageFactory<T,BUFSIZE> & req,const sockaddr_nl & sa)85 bool send(MessageFactory<T, BUFSIZE>& req, const sockaddr_nl& sa) { 86 req.header.nlmsg_seq = mSeq + 1; 87 88 const auto msg = req.build(); 89 if (!msg.has_value()) return false; 90 91 return send(*msg, sa); 92 } 93 94 /** 95 * Send Netlink message. 96 * 97 * \param msg Message to send. 98 * \param sa Destination address. 99 * \return true, if succeeded. 100 */ 101 bool send(const Buffer<nlmsghdr>& msg, const sockaddr_nl& sa); 102 103 /** 104 * Send Netlink message. 105 * 106 * \param msg Message to send. 107 * \param destination Destination PID. 108 * \return true, if succeeded. 109 */ 110 bool send(const Buffer<nlmsghdr>& msg, uint32_t destination); 111 112 /** 113 * Receive one or multiple Netlink messages. 114 * 115 * WARNING: the underlying buffer is owned by Socket class and the data is valid until the next 116 * call to the read function or until deallocation of Socket instance. 117 * 118 * \param maxSize Maximum total size of received messages 119 * \return Buffer view with message data, std::nullopt on error. 120 */ 121 std::optional<Buffer<nlmsghdr>> receive(size_t maxSize = defaultReceiveSize); 122 123 /** 124 * Receive one or multiple Netlink messages and the sender process address. 125 * 126 * WARNING: the underlying buffer is owned by Socket class and the data is valid until the next 127 * call to the read function or until deallocation of Socket instance. 128 * 129 * \param maxSize Maximum total size of received messages. 130 * \return A pair (for use with structured binding) containing: 131 * - buffer view with message data, std::nullopt on error; 132 * - sender process address. 133 */ 134 std::pair<std::optional<Buffer<nlmsghdr>>, sockaddr_nl> receiveFrom( 135 size_t maxSize = defaultReceiveSize); 136 137 /** 138 * Receive matching Netlink message of a given payload type. 139 * 140 * This method should be used if the caller expects exactly one incoming message of exactly 141 * given type (such as ACK). If there is a use case to handle multiple types of messages, 142 * please use receive(size_t) directly and iterate through potential multipart messages. 143 * 144 * If this method is used in such an environment, it will only return the first matching message 145 * from multipart packet and will issue warnings on messages that do not match. 146 * 147 * \param msgtypes Expected message types (such as NLMSG_ERROR). 148 * \param maxSize Maximum total size of received messages. 149 * \return Parsed message or std::nullopt in case of error. 150 */ 151 template <typename T> 152 std::optional<Message<T>> receive(const std::set<nlmsgtype_t>& msgtypes, 153 size_t maxSize = defaultReceiveSize) { 154 const auto msg = receive(msgtypes, maxSize); 155 if (!msg.has_value()) return std::nullopt; 156 157 const auto parsed = Message<T>::parse(*msg); 158 if (!parsed.has_value()) { 159 LOG(WARNING) << "Received matching Netlink message, but couldn't parse it"; 160 return std::nullopt; 161 } 162 163 return parsed; 164 } 165 166 /** 167 * Receive Netlink ACK message. 168 * 169 * \param req Message to match sequence number against. 170 * \return true if received ACK message, false in case of error. 171 */ 172 template <typename T, unsigned BUFSIZE> receiveAck(MessageFactory<T,BUFSIZE> & req)173 bool receiveAck(MessageFactory<T, BUFSIZE>& req) { 174 return receiveAck(req.header.nlmsg_seq); 175 } 176 177 /** 178 * Receive Netlink ACK message. 179 * 180 * \param seq Sequence number of message to ACK. 181 * \return true if received ACK message, false in case of error. 182 */ 183 bool receiveAck(uint32_t seq); 184 185 /** 186 * Fetches the socket PID. 187 * 188 * \return PID that socket is bound to or std::nullopt. 189 */ 190 std::optional<unsigned> getPid(); 191 192 /** 193 * Creates a pollfd object for the socket. 194 * 195 * \param events Value for pollfd.events. 196 * \return A populated pollfd object. 197 */ 198 pollfd preparePoll(short events = 0); 199 200 /** 201 * Join a multicast group. 202 * 203 * \param group Group ID (*not* a bitfield) 204 * \return whether the operation succeeded 205 */ 206 bool addMembership(unsigned group); 207 208 /** 209 * Leave a multicast group. 210 * 211 * \param group Group ID (*not* a bitfield) 212 * \return whether the operation succeeded 213 */ 214 bool dropMembership(unsigned group); 215 216 /** 217 * Live iterator continuously receiving messages from Netlink socket. 218 * 219 * Iteration ends when socket fails to receive a buffer. 220 * 221 * Example: 222 * ``` 223 * nl::Socket sock(NETLINK_ROUTE, 0, RTMGRP_LINK); 224 * for (const auto rawMsg : sock) { 225 * const auto msg = nl::Message<ifinfomsg>::parse(rawMsg, {RTM_NEWLINK, RTM_DELLINK}); 226 * if (!msg.has_value()) continue; 227 * 228 * LOG(INFO) << msg->attributes.get<std::string>(IFLA_IFNAME) 229 * << " is " << ((msg->data.ifi_flags & IFF_UP) ? "up" : "down"); 230 * } 231 * LOG(FATAL) << "Failed to read from Netlink socket"; 232 * ``` 233 */ 234 class receive_iterator { 235 public: 236 receive_iterator(Socket& socket, bool end); 237 238 receive_iterator operator++(); 239 bool operator==(const receive_iterator& other) const; 240 const Buffer<nlmsghdr>& operator*() const; 241 242 private: 243 Socket& mSocket; 244 bool mIsEnd; 245 Buffer<nlmsghdr>::iterator mCurrent; 246 247 void receive(); 248 }; 249 receive_iterator begin(); 250 receive_iterator end(); 251 252 private: 253 const int mProtocol; 254 base::unique_fd mFd; 255 std::vector<uint8_t> mReceiveBuffer; 256 257 bool mFailed = false; 258 uint32_t mSeq = 0; 259 260 bool increaseReceiveBuffer(size_t maxSize); 261 std::optional<Buffer<nlmsghdr>> receive(const std::set<nlmsgtype_t>& msgtypes, size_t maxSize); 262 263 DISALLOW_COPY_AND_ASSIGN(Socket); 264 }; 265 266 } // namespace android::nl 267