1 //
2 // Copyright (C) 2012 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #include "shill/net/rtnl_handler.h"
18 
19 #include <arpa/inet.h>
20 #include <errno.h>
21 #include <fcntl.h>
22 #include <linux/netlink.h>
23 #include <linux/rtnetlink.h>
24 #include <net/if.h>
25 #include <net/if_arp.h>
26 #include <netinet/ether.h>
27 #include <string.h>
28 #include <sys/ioctl.h>
29 #include <sys/socket.h>
30 #include <time.h>
31 #include <unistd.h>
32 
33 #include <base/bind.h>
34 #include <base/logging.h>
35 #include <base/stl_util.h>
36 
37 #include "shill/net/io_handler.h"
38 #include "shill/net/ip_address.h"
39 #include "shill/net/ndisc.h"
40 #include "shill/net/rtnl_listener.h"
41 #include "shill/net/rtnl_message.h"
42 #include "shill/net/sockets.h"
43 
44 using base::Bind;
45 using base::Unretained;
46 using std::string;
47 
48 namespace shill {
49 
50 // Keep this large enough to avoid overflows on IPv6 SNM routing update spikes
51 const int RTNLHandler::kReceiveBufferSize = 512 * 1024;
52 const int RTNLHandler::kInvalidSocket = -1;
53 const int RTNLHandler::kErrorWindowSize = 16;
54 
55 namespace {
56 base::LazyInstance<RTNLHandler> g_rtnl_handler = LAZY_INSTANCE_INITIALIZER;
57 }  // namespace
58 
RTNLHandler()59 RTNLHandler::RTNLHandler()
60     : sockets_(new Sockets()),
61       in_request_(false),
62       rtnl_socket_(kInvalidSocket),
63       request_flags_(0),
64       request_sequence_(0),
65       last_dump_sequence_(0),
66       rtnl_callback_(Bind(&RTNLHandler::ParseRTNL, Unretained(this))),
67       io_handler_factory_(
68           IOHandlerFactoryContainer::GetInstance()->GetIOHandlerFactory()) {
69   error_mask_window_.resize(kErrorWindowSize);
70   VLOG(2) << "RTNLHandler created";
71 }
72 
~RTNLHandler()73 RTNLHandler::~RTNLHandler() {
74   VLOG(2) << "RTNLHandler removed";
75   Stop();
76 }
77 
GetInstance()78 RTNLHandler* RTNLHandler::GetInstance() {
79   return g_rtnl_handler.Pointer();
80 }
81 
Start(uint32_t netlink_groups_mask)82 void RTNLHandler::Start(uint32_t netlink_groups_mask) {
83   struct sockaddr_nl addr;
84 
85   if (rtnl_socket_ != kInvalidSocket) {
86     return;
87   }
88 
89   rtnl_socket_ = sockets_->Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_ROUTE);
90   if (rtnl_socket_ < 0) {
91     LOG(ERROR) << "Failed to open rtnl socket";
92     return;
93   }
94 
95   if (sockets_->SetReceiveBuffer(rtnl_socket_, kReceiveBufferSize)) {
96     LOG(ERROR) << "Failed to increase receive buffer size";
97   }
98 
99   memset(&addr, 0, sizeof(addr));
100   addr.nl_family = AF_NETLINK;
101   addr.nl_groups = netlink_groups_mask;
102 
103   if (sockets_->Bind(rtnl_socket_,
104                     reinterpret_cast<struct sockaddr*>(&addr),
105                     sizeof(addr)) < 0) {
106     sockets_->Close(rtnl_socket_);
107     rtnl_socket_ = kInvalidSocket;
108     LOG(ERROR) << "RTNL socket bind failed";
109     return;
110   }
111 
112   rtnl_handler_.reset(io_handler_factory_->CreateIOInputHandler(
113       rtnl_socket_,
114       rtnl_callback_,
115       Bind(&RTNLHandler::OnReadError, Unretained(this))));
116 
117   NextRequest(last_dump_sequence_);
118   VLOG(2) << "RTNLHandler started";
119 }
120 
Stop()121 void RTNLHandler::Stop() {
122   rtnl_handler_.reset();
123   // Close the socket if it is currently open.
124   if (rtnl_socket_ != kInvalidSocket) {
125     sockets_->Close(rtnl_socket_);
126     rtnl_socket_ = kInvalidSocket;
127   }
128   in_request_ = false;
129   request_flags_ = 0;
130   VLOG(2) << "RTNLHandler stopped";
131 }
132 
AddListener(RTNLListener * to_add)133 void RTNLHandler::AddListener(RTNLListener* to_add) {
134   for (const auto& listener : listeners_) {
135     if (to_add == listener)
136       return;
137   }
138   listeners_.push_back(to_add);
139   VLOG(2) << "RTNLHandler added listener";
140 }
141 
RemoveListener(RTNLListener * to_remove)142 void RTNLHandler::RemoveListener(RTNLListener* to_remove) {
143   for (auto it = listeners_.begin(); it != listeners_.end(); ++it) {
144     if (to_remove == *it) {
145       listeners_.erase(it);
146       return;
147     }
148   }
149   VLOG(2) << "RTNLHandler removed listener";
150 }
151 
SetInterfaceFlags(int interface_index,unsigned int flags,unsigned int change)152 void RTNLHandler::SetInterfaceFlags(int interface_index, unsigned int flags,
153                                     unsigned int change) {
154   if (rtnl_socket_ == kInvalidSocket) {
155     LOG(ERROR) << __func__ << " called while not started.  "
156         "Assuming we are in unit tests.";
157     return;
158   }
159 
160   RTNLMessage msg(
161       RTNLMessage::kTypeLink,
162       RTNLMessage::kModeAdd,
163       NLM_F_REQUEST,
164       0,  // sequence to be filled in by RTNLHandler::SendMessage().
165       0,  // pid.
166       interface_index,
167       IPAddress::kFamilyUnknown);
168 
169   msg.set_link_status(RTNLMessage::LinkStatus(ARPHRD_VOID, flags, change));
170 
171   ErrorMask error_mask;
172   if ((flags & IFF_UP) == 0) {
173     error_mask.insert(ENODEV);
174   }
175 
176   SendMessageWithErrorMask(&msg, error_mask);
177 }
178 
SetInterfaceMTU(int interface_index,unsigned int mtu)179 void RTNLHandler::SetInterfaceMTU(int interface_index, unsigned int mtu) {
180   RTNLMessage msg(
181       RTNLMessage::kTypeLink,
182       RTNLMessage::kModeAdd,
183       NLM_F_REQUEST,
184       0,  // sequence to be filled in by RTNLHandler::SendMessage().
185       0,  // pid.
186       interface_index,
187       IPAddress::kFamilyUnknown);
188 
189   msg.SetAttribute(
190       IFLA_MTU,
191       ByteString(reinterpret_cast<unsigned char*>(&mtu), sizeof(mtu)));
192 
193   CHECK(SendMessage(&msg));
194 }
195 
RequestDump(int request_flags)196 void RTNLHandler::RequestDump(int request_flags) {
197   if (rtnl_socket_ == kInvalidSocket) {
198     LOG(ERROR) << __func__ << " called while not started.  "
199         "Assuming we are in unit tests.";
200     return;
201   }
202 
203   request_flags_ |= request_flags;
204 
205   VLOG(2) << "RTNLHandler got request to dump "
206           << std::showbase << std::hex
207           << request_flags
208           << std::dec << std::noshowbase;
209 
210   if (!in_request_) {
211     NextRequest(last_dump_sequence_);
212   }
213 }
214 
DispatchEvent(int type,const RTNLMessage & msg)215 void RTNLHandler::DispatchEvent(int type, const RTNLMessage& msg) {
216   for (const auto& listener : listeners_) {
217     listener->NotifyEvent(type, msg);
218   }
219 }
220 
NextRequest(uint32_t seq)221 void RTNLHandler::NextRequest(uint32_t seq) {
222   int flag = 0;
223   RTNLMessage::Type type;
224 
225   VLOG(2) << "RTNLHandler nextrequest " << seq << " "
226           << last_dump_sequence_
227           << std::showbase << std::hex
228           << " " << request_flags_
229           << std::dec << std::noshowbase;
230 
231   if (seq != last_dump_sequence_)
232     return;
233 
234   IPAddress::Family family = IPAddress::kFamilyUnknown;
235   if ((request_flags_ & kRequestAddr) != 0) {
236     type = RTNLMessage::kTypeAddress;
237     flag = kRequestAddr;
238   } else if ((request_flags_ & kRequestRoute) != 0) {
239     type = RTNLMessage::kTypeRoute;
240     flag = kRequestRoute;
241   } else if ((request_flags_ & kRequestLink) != 0) {
242     type = RTNLMessage::kTypeLink;
243     flag = kRequestLink;
244   } else if ((request_flags_ & kRequestNeighbor) != 0) {
245     type = RTNLMessage::kTypeNeighbor;
246     flag = kRequestNeighbor;
247   } else if ((request_flags_ & kRequestBridgeNeighbor) != 0) {
248     type = RTNLMessage::kTypeNeighbor;
249     flag = kRequestBridgeNeighbor;
250     family = AF_BRIDGE;
251   } else {
252     VLOG(2) << "Done with requests";
253     in_request_ = false;
254     return;
255   }
256 
257   RTNLMessage msg(
258       type,
259       RTNLMessage::kModeGet,
260       0,
261       0,
262       0,
263       0,
264       family);
265   CHECK(SendMessage(&msg));
266 
267   last_dump_sequence_ = msg.seq();
268   request_flags_ &= ~flag;
269   in_request_ = true;
270 }
271 
ParseRTNL(InputData * data)272 void RTNLHandler::ParseRTNL(InputData* data) {
273   unsigned char* buf = data->buf;
274   unsigned char* end = buf + data->len;
275 
276   while (buf < end) {
277     struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf);
278     if (!NLMSG_OK(hdr, static_cast<unsigned int>(end - buf)))
279       break;
280 
281     VLOG(5) << __func__ << ": received payload (" << end - buf << ")";
282 
283     RTNLMessage msg;
284     ByteString payload(reinterpret_cast<unsigned char*>(hdr), hdr->nlmsg_len);
285     VLOG(5) << "RTNL received payload length " << payload.GetLength()
286             << ": \"" << payload.HexEncode() << "\"";
287     if (!msg.Decode(payload)) {
288       VLOG(5) << __func__ << ": rtnl packet type "
289               << hdr->nlmsg_type
290               << " length " << hdr->nlmsg_len
291               << " sequence " << hdr->nlmsg_seq;
292 
293       switch (hdr->nlmsg_type) {
294         case NLMSG_NOOP:
295         case NLMSG_OVERRUN:
296           break;
297         case NLMSG_DONE:
298           GetAndClearErrorMask(hdr->nlmsg_seq);  // Clear any queued error mask.
299           NextRequest(hdr->nlmsg_seq);
300           break;
301         case NLMSG_ERROR:
302           {
303             struct nlmsgerr* err =
304                 reinterpret_cast<nlmsgerr*>(NLMSG_DATA(hdr));
305             int error_number = -err->error;
306             std::ostringstream message;
307             message << "sequence " << hdr->nlmsg_seq << " received error "
308                     << error_number << " ("
309                     << strerror(error_number) << ")";
310             if (!ContainsValue(GetAndClearErrorMask(hdr->nlmsg_seq),
311                                error_number)) {
312               LOG(ERROR) << message.str();
313             } else {
314               VLOG(3) << message.str();
315             }
316             break;
317           }
318         default:
319           NOTIMPLEMENTED() << "Unknown NL message type.";
320       }
321     } else {
322       switch (msg.type()) {
323         case RTNLMessage::kTypeLink:
324           DispatchEvent(kRequestLink, msg);
325           break;
326         case RTNLMessage::kTypeAddress:
327           DispatchEvent(kRequestAddr, msg);
328           break;
329         case RTNLMessage::kTypeRoute:
330           DispatchEvent(kRequestRoute, msg);
331           break;
332         case RTNLMessage::kTypeRdnss:
333           DispatchEvent(kRequestRdnss, msg);
334           break;
335         case RTNLMessage::kTypeNeighbor:
336           DispatchEvent(kRequestNeighbor, msg);
337           break;
338         case RTNLMessage::kTypeDnssl:
339           NOTIMPLEMENTED();
340           break;
341         default:
342           NOTIMPLEMENTED() << "Unknown RTNL message type.";
343       }
344     }
345     buf += hdr->nlmsg_len;
346   }
347 }
348 
AddressRequest(int interface_index,RTNLMessage::Mode mode,int flags,const IPAddress & local,const IPAddress & broadcast,const IPAddress & peer)349 bool RTNLHandler::AddressRequest(int interface_index,
350                                  RTNLMessage::Mode mode,
351                                  int flags,
352                                  const IPAddress& local,
353                                  const IPAddress& broadcast,
354                                  const IPAddress& peer) {
355   CHECK(local.family() == broadcast.family());
356   CHECK(local.family() == peer.family());
357 
358   RTNLMessage msg(
359       RTNLMessage::kTypeAddress,
360       mode,
361       NLM_F_REQUEST | flags,
362       0,
363       0,
364       interface_index,
365       local.family());
366 
367   msg.set_address_status(RTNLMessage::AddressStatus(
368       local.prefix(),
369       0,
370       0));
371 
372   msg.SetAttribute(IFA_LOCAL, local.address());
373   if (!broadcast.IsDefault()) {
374     msg.SetAttribute(IFA_BROADCAST, broadcast.address());
375   }
376   if (!peer.IsDefault()) {
377     msg.SetAttribute(IFA_ADDRESS, peer.address());
378   }
379 
380   return SendMessage(&msg);
381 }
382 
AddInterfaceAddress(int interface_index,const IPAddress & local,const IPAddress & broadcast,const IPAddress & peer)383 bool RTNLHandler::AddInterfaceAddress(int interface_index,
384                                       const IPAddress& local,
385                                       const IPAddress& broadcast,
386                                       const IPAddress& peer) {
387     return AddressRequest(interface_index,
388                           RTNLMessage::kModeAdd,
389                           NLM_F_CREATE | NLM_F_EXCL | NLM_F_ECHO,
390                           local,
391                           broadcast,
392                           peer);
393 }
394 
RemoveInterfaceAddress(int interface_index,const IPAddress & local)395 bool RTNLHandler::RemoveInterfaceAddress(int interface_index,
396                                          const IPAddress& local) {
397   return AddressRequest(interface_index,
398                         RTNLMessage::kModeDelete,
399                         NLM_F_ECHO,
400                         local,
401                         IPAddress(local.family()),
402                         IPAddress(local.family()));
403 }
404 
RemoveInterface(int interface_index)405 bool RTNLHandler::RemoveInterface(int interface_index) {
406   RTNLMessage msg(
407       RTNLMessage::kTypeLink,
408       RTNLMessage::kModeDelete,
409       NLM_F_REQUEST,
410       0,
411       0,
412       interface_index,
413       IPAddress::kFamilyUnknown);
414   return SendMessage(&msg);
415 }
416 
GetInterfaceIndex(const string & interface_name)417 int RTNLHandler::GetInterfaceIndex(const string& interface_name) {
418   if (interface_name.empty()) {
419     LOG(ERROR) << "Empty interface name -- unable to obtain index.";
420     return -1;
421   }
422   struct ifreq ifr;
423   if (interface_name.size() >= sizeof(ifr.ifr_name)) {
424     LOG(ERROR) << "Interface name too long: " << interface_name.size() << " >= "
425                << sizeof(ifr.ifr_name);
426     return -1;
427   }
428   int socket = sockets_->Socket(PF_INET, SOCK_DGRAM, 0);
429   if (socket < 0) {
430     PLOG(ERROR) << "Unable to open INET socket";
431     return -1;
432   }
433   ScopedSocketCloser socket_closer(sockets_.get(), socket);
434   memset(&ifr, 0, sizeof(ifr));
435   strncpy(ifr.ifr_name, interface_name.c_str(), sizeof(ifr.ifr_name));
436   if (sockets_->Ioctl(socket, SIOCGIFINDEX, &ifr) < 0) {
437     PLOG(ERROR) << "SIOCGIFINDEX error for " << interface_name;
438     return -1;
439   }
440   return ifr.ifr_ifindex;
441 }
442 
SendMessageWithErrorMask(RTNLMessage * message,const ErrorMask & error_mask)443 bool RTNLHandler::SendMessageWithErrorMask(RTNLMessage* message,
444                                            const ErrorMask& error_mask) {
445   VLOG(5) << __func__ << " sequence " << request_sequence_
446           << " message type " << message->type()
447           << " mode " << message->mode()
448           << " with error mask size " << error_mask.size();
449 
450   SetErrorMask(request_sequence_, error_mask);
451   message->set_seq(request_sequence_);
452   ByteString msgdata = message->Encode();
453 
454   if (msgdata.GetLength() == 0) {
455     return false;
456   }
457 
458   VLOG(5) << "RTNL sending payload with request sequence "
459                 << request_sequence_ << ", length " << msgdata.GetLength()
460                 << ": \"" << msgdata.HexEncode() << "\"";
461 
462   request_sequence_++;
463 
464   if (sockets_->Send(rtnl_socket_,
465                      msgdata.GetConstData(),
466                      msgdata.GetLength(),
467                      0) < 0) {
468     PLOG(ERROR) << "RTNL send failed";
469     return false;
470   }
471 
472   return true;
473 }
474 
SendMessage(RTNLMessage * message)475 bool RTNLHandler::SendMessage(RTNLMessage* message) {
476   ErrorMask error_mask;
477   if (message->mode() == RTNLMessage::kModeAdd) {
478     error_mask = { EEXIST };
479   } else if (message->mode() == RTNLMessage::kModeDelete) {
480     error_mask = { ESRCH, ENODEV };
481     if (message->type() == RTNLMessage::kTypeAddress) {
482       error_mask.insert(EADDRNOTAVAIL);
483     }
484   }
485   return SendMessageWithErrorMask(message, error_mask);
486 }
487 
IsSequenceInErrorMaskWindow(uint32_t sequence)488 bool RTNLHandler::IsSequenceInErrorMaskWindow(uint32_t sequence) {
489   return (request_sequence_ - sequence) < kErrorWindowSize;
490 }
491 
SetErrorMask(uint32_t sequence,const ErrorMask & error_mask)492 void RTNLHandler::SetErrorMask(uint32_t sequence, const ErrorMask& error_mask) {
493   if (IsSequenceInErrorMaskWindow(sequence)) {
494     error_mask_window_[sequence % kErrorWindowSize] = error_mask;
495   }
496 }
497 
GetAndClearErrorMask(uint32_t sequence)498 RTNLHandler::ErrorMask RTNLHandler::GetAndClearErrorMask(uint32_t sequence) {
499   ErrorMask error_mask;
500   if (IsSequenceInErrorMaskWindow(sequence)) {
501     error_mask.swap(error_mask_window_[sequence % kErrorWindowSize]);
502   }
503   return error_mask;
504 }
505 
OnReadError(const string & error_msg)506 void RTNLHandler::OnReadError(const string& error_msg) {
507   LOG(FATAL) << "RTNL Socket read returns error: "
508              << error_msg;
509 }
510 
511 }  // namespace shill
512