1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <android-base/macros.h>
20 #include <android-base/unique_fd.h>
21 #include <libnl++/Buffer.h>
22 #include <libnl++/Message.h>
23 #include <libnl++/MessageFactory.h>
24 
25 #include <linux/netlink.h>
26 #include <poll.h>
27 
28 #include <optional>
29 #include <set>
30 #include <vector>
31 
32 namespace android::nl {
33 
34 /**
35  * A wrapper around AF_NETLINK sockets.
36  *
37  * This class is not thread safe to use a single instance between multiple threads, but it's fine to
38  * use multiple instances over multiple threads.
39  */
40 class Socket {
41   public:
42     static constexpr size_t defaultReceiveSize = 8192;
43 
44     /**
45      * Socket constructor.
46      *
47      * \param protocol the Netlink protocol to use.
48      * \param pid port id. Default value of 0 allows the kernel to assign us a unique pid.
49      *        (NOTE: this is NOT the same as process id).
50      * \param groups Netlink multicast groups to listen to. This is a 32-bit bitfield, where each
51      *        bit is a different group. Default value of 0 means no groups are selected.
52      *        See man netlink.7.
53      * for more details.
54      */
55     Socket(int protocol, unsigned pid = 0, uint32_t groups = 0);
56 
57     /**
58      * Attempt to clear POLLERR by recv-ing.
59      * TODO(224850481): determine if this is necessary, or if the socket is locked up anyway.
60      */
61     void clearPollErr();
62 
63     /**
64      * Send Netlink message with incremented sequence number to the Kernel.
65      *
66      * \param msg Message to send. Its sequence number will be updated.
67      * \return true, if succeeded.
68      */
69     template <typename T, unsigned BUFSIZE>
send(MessageFactory<T,BUFSIZE> & req)70     bool send(MessageFactory<T, BUFSIZE>& req) {
71         sockaddr_nl sa = {};
72         sa.nl_family = AF_NETLINK;
73         sa.nl_pid = 0;  // Kernel
74         return send(req, sa);
75     }
76 
77     /**
78      * Send Netlink message with incremented sequence number.
79      *
80      * \param msg Message to send. Its sequence number will be updated.
81      * \param sa Destination address.
82      * \return true, if succeeded.
83      */
84     template <typename T, unsigned BUFSIZE>
send(MessageFactory<T,BUFSIZE> & req,const sockaddr_nl & sa)85     bool send(MessageFactory<T, BUFSIZE>& req, const sockaddr_nl& sa) {
86         req.header.nlmsg_seq = mSeq + 1;
87 
88         const auto msg = req.build();
89         if (!msg.has_value()) return false;
90 
91         return send(*msg, sa);
92     }
93 
94     /**
95      * Send Netlink message.
96      *
97      * \param msg Message to send.
98      * \param sa Destination address.
99      * \return true, if succeeded.
100      */
101     bool send(const Buffer<nlmsghdr>& msg, const sockaddr_nl& sa);
102 
103     /**
104      * Send Netlink message.
105      *
106      * \param msg Message to send.
107      * \param destination Destination PID.
108      * \return true, if succeeded.
109      */
110     bool send(const Buffer<nlmsghdr>& msg, uint32_t destination);
111 
112     /**
113      * Receive one or multiple Netlink messages.
114      *
115      * WARNING: the underlying buffer is owned by Socket class and the data is valid until the next
116      * call to the read function or until deallocation of Socket instance.
117      *
118      * \param maxSize Maximum total size of received messages
119      * \return Buffer view with message data, std::nullopt on error.
120      */
121     std::optional<Buffer<nlmsghdr>> receive(size_t maxSize = defaultReceiveSize);
122 
123     /**
124      * Receive one or multiple Netlink messages and the sender process address.
125      *
126      * WARNING: the underlying buffer is owned by Socket class and the data is valid until the next
127      * call to the read function or until deallocation of Socket instance.
128      *
129      * \param maxSize Maximum total size of received messages.
130      * \return A pair (for use with structured binding) containing:
131      *         - buffer view with message data, std::nullopt on error;
132      *         - sender process address.
133      */
134     std::pair<std::optional<Buffer<nlmsghdr>>, sockaddr_nl> receiveFrom(
135             size_t maxSize = defaultReceiveSize);
136 
137     /**
138      * Receive matching Netlink message of a given payload type.
139      *
140      * This method should be used if the caller expects exactly one incoming message of exactly
141      * given type (such as ACK). If there is a use case to handle multiple types of messages,
142      * please use receive(size_t) directly and iterate through potential multipart messages.
143      *
144      * If this method is used in such an environment, it will only return the first matching message
145      * from multipart packet and will issue warnings on messages that do not match.
146      *
147      * \param msgtypes Expected message types (such as NLMSG_ERROR).
148      * \param maxSize Maximum total size of received messages.
149      * \return Parsed message or std::nullopt in case of error.
150      */
151     template <typename T>
152     std::optional<Message<T>> receive(const std::set<nlmsgtype_t>& msgtypes,
153                                       size_t maxSize = defaultReceiveSize) {
154         const auto msg = receive(msgtypes, maxSize);
155         if (!msg.has_value()) return std::nullopt;
156 
157         const auto parsed = Message<T>::parse(*msg);
158         if (!parsed.has_value()) {
159             LOG(WARNING) << "Received matching Netlink message, but couldn't parse it";
160             return std::nullopt;
161         }
162 
163         return parsed;
164     }
165 
166     /**
167      * Receive Netlink ACK message.
168      *
169      * \param req Message to match sequence number against.
170      * \return true if received ACK message, false in case of error.
171      */
172     template <typename T, unsigned BUFSIZE>
receiveAck(MessageFactory<T,BUFSIZE> & req)173     bool receiveAck(MessageFactory<T, BUFSIZE>& req) {
174         return receiveAck(req.header.nlmsg_seq);
175     }
176 
177     /**
178      * Receive Netlink ACK message.
179      *
180      * \param seq Sequence number of message to ACK.
181      * \return true if received ACK message, false in case of error.
182      */
183     bool receiveAck(uint32_t seq);
184 
185     /**
186      * Fetches the socket PID.
187      *
188      * \return PID that socket is bound to or std::nullopt.
189      */
190     std::optional<unsigned> getPid();
191 
192     /**
193      * Creates a pollfd object for the socket.
194      *
195      * \param events Value for pollfd.events.
196      * \return A populated pollfd object.
197      */
198     pollfd preparePoll(short events = 0);
199 
200     /**
201      * Join a multicast group.
202      *
203      * \param group Group ID (*not* a bitfield)
204      * \return whether the operation succeeded
205      */
206     bool addMembership(unsigned group);
207 
208     /**
209      * Leave a multicast group.
210      *
211      * \param group Group ID (*not* a bitfield)
212      * \return whether the operation succeeded
213      */
214     bool dropMembership(unsigned group);
215 
216     /**
217      * Live iterator continuously receiving messages from Netlink socket.
218      *
219      * Iteration ends when socket fails to receive a buffer.
220      *
221      * Example:
222      * ```
223      *     nl::Socket sock(NETLINK_ROUTE, 0, RTMGRP_LINK);
224      *     for (const auto rawMsg : sock) {
225      *         const auto msg = nl::Message<ifinfomsg>::parse(rawMsg, {RTM_NEWLINK, RTM_DELLINK});
226      *         if (!msg.has_value()) continue;
227      *
228      *         LOG(INFO) << msg->attributes.get<std::string>(IFLA_IFNAME)
229      *                   << " is " << ((msg->data.ifi_flags & IFF_UP) ? "up" : "down");
230      *     }
231      *     LOG(FATAL) << "Failed to read from Netlink socket";
232      * ```
233      */
234     class receive_iterator {
235       public:
236         receive_iterator(Socket& socket, bool end);
237 
238         receive_iterator operator++();
239         bool operator==(const receive_iterator& other) const;
240         const Buffer<nlmsghdr>& operator*() const;
241 
242       private:
243         Socket& mSocket;
244         bool mIsEnd;
245         Buffer<nlmsghdr>::iterator mCurrent;
246 
247         void receive();
248     };
249     receive_iterator begin();
250     receive_iterator end();
251 
252   private:
253     const int mProtocol;
254     base::unique_fd mFd;
255     std::vector<uint8_t> mReceiveBuffer;
256 
257     bool mFailed = false;
258     uint32_t mSeq = 0;
259 
260     bool increaseReceiveBuffer(size_t maxSize);
261     std::optional<Buffer<nlmsghdr>> receive(const std::set<nlmsgtype_t>& msgtypes, size_t maxSize);
262 
263     DISALLOW_COPY_AND_ASSIGN(Socket);
264 };
265 
266 }  // namespace android::nl
267