1 //
2 // Copyright (C) 2012 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #include "shill/net/rtnl_handler.h"
18 
19 #include <string>
20 
21 #include <gtest/gtest.h>
22 #include <net/if.h>
23 #include <sys/socket.h>
24 #include <linux/netlink.h>  // Needs typedefs from sys/socket.h.
25 #include <linux/rtnetlink.h>
26 #include <sys/ioctl.h>
27 
28 #include <base/bind.h>
29 
30 #include "shill/mock_log.h"
31 #include "shill/net/mock_io_handler_factory.h"
32 #include "shill/net/mock_sockets.h"
33 #include "shill/net/rtnl_message.h"
34 
35 using base::Bind;
36 using base::Callback;
37 using base::Unretained;
38 using std::string;
39 using testing::_;
40 using testing::A;
41 using testing::DoAll;
42 using testing::ElementsAre;
43 using testing::HasSubstr;
44 using testing::Return;
45 using testing::ReturnArg;
46 using testing::StrictMock;
47 using testing::Test;
48 
49 namespace shill {
50 
51 namespace {
52 
53 const int kTestInterfaceIndex = 4;
54 
ACTION(SetInterfaceIndex)55 ACTION(SetInterfaceIndex) {
56   if (arg2) {
57     reinterpret_cast<struct ifreq*>(arg2)->ifr_ifindex = kTestInterfaceIndex;
58   }
59 }
60 
61 MATCHER_P(MessageType, message_type, "") {
62   return std::get<0>(arg).type() == message_type;
63 }
64 
65 }  // namespace
66 
67 class RTNLHandlerTest : public Test {
68  public:
RTNLHandlerTest()69   RTNLHandlerTest()
70       : sockets_(new StrictMock<MockSockets>()),
71         callback_(Bind(&RTNLHandlerTest::HandlerCallback, Unretained(this))),
72         dummy_message_(RTNLMessage::kTypeLink,
73                        RTNLMessage::kModeGet,
74                        0,
75                        0,
76                        0,
77                        0,
78                        IPAddress::kFamilyUnknown) {
79   }
80 
SetUp()81   virtual void SetUp() {
82     RTNLHandler::GetInstance()->io_handler_factory_ = &io_handler_factory_;
83     RTNLHandler::GetInstance()->sockets_.reset(sockets_);
84   }
85 
TearDown()86   virtual void TearDown() {
87     RTNLHandler::GetInstance()->Stop();
88   }
89 
GetRequestSequence()90   uint32_t GetRequestSequence() {
91     return RTNLHandler::GetInstance()->request_sequence_;
92   }
93 
SetRequestSequence(uint32_t sequence)94   void SetRequestSequence(uint32_t sequence) {
95     RTNLHandler::GetInstance()->request_sequence_ = sequence;
96   }
97 
IsSequenceInErrorMaskWindow(uint32_t sequence)98   bool IsSequenceInErrorMaskWindow(uint32_t sequence) {
99     return RTNLHandler::GetInstance()->IsSequenceInErrorMaskWindow(sequence);
100   }
101 
SetErrorMask(uint32_t sequence,const RTNLHandler::ErrorMask & error_mask)102   void SetErrorMask(uint32_t sequence,
103                     const RTNLHandler::ErrorMask& error_mask) {
104     return RTNLHandler::GetInstance()->SetErrorMask(sequence, error_mask);
105   }
106 
GetAndClearErrorMask(uint32_t sequence)107   RTNLHandler::ErrorMask GetAndClearErrorMask(uint32_t sequence) {
108     return RTNLHandler::GetInstance()->GetAndClearErrorMask(sequence);
109   }
110 
GetErrorWindowSize()111   int GetErrorWindowSize() {
112     return  RTNLHandler::kErrorWindowSize;
113   }
114 
115   MOCK_METHOD1(HandlerCallback, void(const RTNLMessage&));
116 
117  protected:
118   static const int kTestSocket;
119   static const int kTestDeviceIndex;
120   static const char kTestDeviceName[];
121 
122   void AddLink();
123   void AddNeighbor();
124   void StartRTNLHandler();
125   void StopRTNLHandler();
126   void ReturnError(uint32_t sequence, int error_number);
127 
128   MockSockets* sockets_;
129   StrictMock<MockIOHandlerFactory> io_handler_factory_;
130   Callback<void(const RTNLMessage&)> callback_;
131   RTNLMessage dummy_message_;
132 };
133 
134 const int RTNLHandlerTest::kTestSocket = 123;
135 const int RTNLHandlerTest::kTestDeviceIndex = 123456;
136 const char RTNLHandlerTest::kTestDeviceName[] = "test-device";
137 
StartRTNLHandler()138 void RTNLHandlerTest::StartRTNLHandler() {
139   EXPECT_CALL(*sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_ROUTE))
140       .WillOnce(Return(kTestSocket));
141   EXPECT_CALL(*sockets_, Bind(kTestSocket, _, sizeof(sockaddr_nl)))
142       .WillOnce(Return(0));
143   EXPECT_CALL(*sockets_, SetReceiveBuffer(kTestSocket, _)).WillOnce(Return(0));
144   EXPECT_CALL(io_handler_factory_, CreateIOInputHandler(kTestSocket, _, _));
145   RTNLHandler::GetInstance()->Start(0);
146 }
147 
StopRTNLHandler()148 void RTNLHandlerTest::StopRTNLHandler() {
149   EXPECT_CALL(*sockets_, Close(kTestSocket)).WillOnce(Return(0));
150   RTNLHandler::GetInstance()->Stop();
151 }
152 
AddLink()153 void RTNLHandlerTest::AddLink() {
154   RTNLMessage message(RTNLMessage::kTypeLink,
155                       RTNLMessage::kModeAdd,
156                       0,
157                       0,
158                       0,
159                       kTestDeviceIndex,
160                       IPAddress::kFamilyIPv4);
161   message.SetAttribute(static_cast<uint16_t>(IFLA_IFNAME),
162                        ByteString(string(kTestDeviceName), true));
163   ByteString b(message.Encode());
164   InputData data(b.GetData(), b.GetLength());
165   RTNLHandler::GetInstance()->ParseRTNL(&data);
166 }
167 
AddNeighbor()168 void RTNLHandlerTest::AddNeighbor() {
169   RTNLMessage message(RTNLMessage::kTypeNeighbor,
170                       RTNLMessage::kModeAdd,
171                       0,
172                       0,
173                       0,
174                       kTestDeviceIndex,
175                       IPAddress::kFamilyIPv4);
176   ByteString encoded(message.Encode());
177   InputData data(encoded.GetData(), encoded.GetLength());
178   RTNLHandler::GetInstance()->ParseRTNL(&data);
179 }
180 
ReturnError(uint32_t sequence,int error_number)181 void RTNLHandlerTest::ReturnError(uint32_t sequence, int error_number) {
182   struct {
183     struct nlmsghdr hdr;
184     struct nlmsgerr err;
185   } errmsg;
186 
187   memset(&errmsg, 0, sizeof(errmsg));
188   errmsg.hdr.nlmsg_type = NLMSG_ERROR;
189   errmsg.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(errmsg.err));
190   errmsg.hdr.nlmsg_seq = sequence;
191   errmsg.err.error = -error_number;
192 
193   InputData data(reinterpret_cast<unsigned char*>(&errmsg), sizeof(errmsg));
194   RTNLHandler::GetInstance()->ParseRTNL(&data);
195 }
196 
// Verifies that a link message and a neighbor message fed to the handler
// each reach HandlerCallback() exactly once.
TEST_F(RTNLHandlerTest, ListenersInvoked) {
  StartRTNLHandler();

  // One listener per request flag, both sharing |callback_|.  The listeners
  // only need to live until the end of this test, so they use automatic
  // storage; they are destroyed in reverse declaration order on scope exit.
  RTNLListener link_listener(RTNLHandler::kRequestLink, callback_);
  RTNLListener neighbor_listener(RTNLHandler::kRequestNeighbor, callback_);

  EXPECT_CALL(*this, HandlerCallback(A<const RTNLMessage&>()))
      .With(MessageType(RTNLMessage::kTypeLink));
  EXPECT_CALL(*this, HandlerCallback(A<const RTNLMessage&>()))
      .With(MessageType(RTNLMessage::kTypeNeighbor));

  AddLink();
  AddNeighbor();

  StopRTNLHandler();
}
215 
// Exercises RTNLHandler::GetInterfaceIndex() (note: despite the test's name,
// the method under test is GetInterfaceIndex()).
TEST_F(RTNLHandlerTest, GetInterfaceName) {
  // An empty name is rejected; the strict socket mock has no expectations
  // yet, so no socket call may be made.
  EXPECT_EQ(-1, RTNLHandler::GetInstance()->GetInterfaceIndex(""));
  {
    // A name exactly as long as ifreq::ifr_name is rejected as well,
    // presumably because it leaves no room for the terminating NUL.
    struct ifreq ifr;
    string name(sizeof(ifr.ifr_name), 'x');
    EXPECT_EQ(-1, RTNLHandler::GetInstance()->GetInterfaceIndex(name));
  }

  // Three lookups follow:
  //   1. socket creation fails;
  //   2. the socket opens but the SIOCGIFINDEX ioctl fails;
  //   3. the socket opens and the ioctl fills in kTestInterfaceIndex.
  // The socket must be closed in both cases where it was opened.
  EXPECT_CALL(*sockets_, Socket(PF_INET, SOCK_DGRAM, 0))
      .Times(3)
      .WillOnce(Return(-1))
      .WillRepeatedly(Return(kTestSocket));
  EXPECT_CALL(*sockets_, Ioctl(kTestSocket, SIOCGIFINDEX, _))
      .WillOnce(Return(-1))
      .WillOnce(DoAll(SetInterfaceIndex(), Return(0)));
  EXPECT_CALL(*sockets_, Close(kTestSocket))
      .Times(2)
      .WillRepeatedly(Return(0));
  EXPECT_EQ(-1, RTNLHandler::GetInstance()->GetInterfaceIndex("eth0"));
  EXPECT_EQ(-1, RTNLHandler::GetInstance()->GetInterfaceIndex("wlan0"));
  EXPECT_EQ(kTestInterfaceIndex,
            RTNLHandler::GetInstance()->GetInterfaceIndex("usb0"));
}
240 
// Probes both edges of the error-mask window: the current request sequence
// and the GetErrorWindowSize() - 1 sequence numbers before it are inside;
// anything newer or older is outside.
TEST_F(RTNLHandlerTest, IsSequenceInErrorMaskWindow) {
  const uint32_t kRequestSequence = 1234;
  const int window_size = GetErrorWindowSize();
  SetRequestSequence(kRequestSequence);

  // A sequence number ahead of the current one is not in the window.
  EXPECT_FALSE(IsSequenceInErrorMaskWindow(kRequestSequence + 1));

  // The newest in-window entries.
  EXPECT_TRUE(IsSequenceInErrorMaskWindow(kRequestSequence));
  EXPECT_TRUE(IsSequenceInErrorMaskWindow(kRequestSequence - 1));

  // The oldest entry that is still inside the window.
  EXPECT_TRUE(IsSequenceInErrorMaskWindow(kRequestSequence - window_size + 1));

  // One step further back, and beyond, falls outside.
  EXPECT_FALSE(IsSequenceInErrorMaskWindow(kRequestSequence - window_size));
  EXPECT_FALSE(
      IsSequenceInErrorMaskWindow(kRequestSequence - window_size - 1));
}
254 
// A failed socket send makes SendMessage() return false, yet the request
// sequence number still moves forward by one.
TEST_F(RTNLHandlerTest, SendMessageReturnsErrorAndAdvancesSequenceNumber) {
  const uint32_t kSequenceNumber = 123;
  StartRTNLHandler();
  SetRequestSequence(kSequenceNumber);

  // Make the mocked send fail.
  EXPECT_CALL(*sockets_, Send(kTestSocket, _, _, 0)).WillOnce(Return(-1));
  auto* handler = RTNLHandler::GetInstance();
  EXPECT_FALSE(handler->SendMessage(&dummy_message_));

  // The sequence number advances regardless of the failure.
  EXPECT_EQ(kSequenceNumber + 1, GetRequestSequence());

  StopRTNLHandler();
}
266 
// Sending with an empty error mask leaves no mask stored for the sequence
// number that was used, even if one had been stored there beforehand.
TEST_F(RTNLHandlerTest, SendMessageWithEmptyMask) {
  const uint32_t kSequenceNumber = 123;
  StartRTNLHandler();
  SetRequestSequence(kSequenceNumber);

  // Pre-seed a non-empty mask for the sequence number about to be used.
  SetErrorMask(kSequenceNumber, {1, 2, 3});

  // The mocked send reports its length argument, i.e. a complete write.
  EXPECT_CALL(*sockets_, Send(kTestSocket, _, _, 0)).WillOnce(ReturnArg<2>());
  auto* handler = RTNLHandler::GetInstance();
  EXPECT_TRUE(handler->SendMessageWithErrorMask(&dummy_message_, {}));

  EXPECT_EQ(kSequenceNumber + 1, GetRequestSequence());
  EXPECT_TRUE(GetAndClearErrorMask(kSequenceNumber).empty());

  StopRTNLHandler();
}
279 
// Sending with an explicit error mask stores that mask under the sequence
// number used for the send, and reading it back clears it.
TEST_F(RTNLHandlerTest, SendMessageWithErrorMask) {
  const uint32_t kSequenceNumber = 123;
  StartRTNLHandler();
  SetRequestSequence(kSequenceNumber);

  // The mocked send reports its length argument, i.e. a complete write.
  EXPECT_CALL(*sockets_, Send(kTestSocket, _, _, 0)).WillOnce(ReturnArg<2>());
  auto* handler = RTNLHandler::GetInstance();
  EXPECT_TRUE(handler->SendMessageWithErrorMask(&dummy_message_, {1, 2, 3}));
  EXPECT_EQ(kSequenceNumber + 1, GetRequestSequence());

  // Nothing is stored for the next sequence number; the mask sits under the
  // one that was just consumed.
  EXPECT_TRUE(GetAndClearErrorMask(kSequenceNumber + 1).empty());
  EXPECT_THAT(GetAndClearErrorMask(kSequenceNumber), ElementsAre(1, 2, 3));

  // The read above cleared the mask, so a second read comes back empty.
  EXPECT_TRUE(GetAndClearErrorMask(kSequenceNumber).empty());

  StopRTNLHandler();
}
295 
// SendMessage() (the overload without an explicit mask) stores an error mask
// chosen from the message's type and mode.  Each table row below is sent
// with the same sequence number and the stored mask is compared afterwards.
TEST_F(RTNLHandlerTest, SendMessageInferredErrorMasks) {
  struct Expectation {
    RTNLMessage::Type type;
    RTNLMessage::Mode mode;
    RTNLHandler::ErrorMask mask;
  };
  Expectation expectations[] = {
    { RTNLMessage::kTypeLink, RTNLMessage::kModeGet, {} },
    { RTNLMessage::kTypeLink, RTNLMessage::kModeAdd, {EEXIST} },
    { RTNLMessage::kTypeLink, RTNLMessage::kModeDelete, {ESRCH, ENODEV} },
    { RTNLMessage::kTypeAddress, RTNLMessage::kModeDelete,
         {ESRCH, ENODEV, EADDRNOTAVAIL} }
  };
  const uint32_t kSequenceNumber = 123;

  // The handler is not started in this test, so any socket descriptor is
  // accepted; every send reports its full length argument.
  EXPECT_CALL(*sockets_, Send(_, _, _, 0)).WillRepeatedly(ReturnArg<2>());

  for (const Expectation& row : expectations) {
    // Rewind so each message is sent under the same sequence number.
    SetRequestSequence(kSequenceNumber);
    RTNLMessage outgoing(row.type,
                         row.mode,
                         0,
                         0,
                         0,
                         0,
                         IPAddress::kFamilyUnknown);
    EXPECT_TRUE(RTNLHandler::GetInstance()->SendMessage(&outgoing));
    EXPECT_EQ(row.mask, GetAndClearErrorMask(kSequenceNumber));
  }
}
323 
// An error reply is logged at LOG_ERROR unless the error mask stored for its
// sequence number contains the error, and a stored mask is used only once.
TEST_F(RTNLHandlerTest, MaskedError) {
  const uint32_t kSequenceNumber = 123;
  StartRTNLHandler();
  SetRequestSequence(kSequenceNumber);

  // Send one message under kSequenceNumber with errors 1, 2 and 3 masked.
  EXPECT_CALL(*sockets_, Send(kTestSocket, _, _, 0)).WillOnce(ReturnArg<2>());
  auto* handler = RTNLHandler::GetInstance();
  EXPECT_TRUE(handler->SendMessageWithErrorMask(&dummy_message_, {1, 2, 3}));

  ScopedMockLog log;

  // This error will not be masked since its sequence number has no mask.
  EXPECT_CALL(log, Log(logging::LOG_ERROR, _, HasSubstr("error 1"))).Times(1);
  ReturnError(kSequenceNumber - 1, 1);

  // This error will be masked: error 2 is in the mask stored for this
  // sequence number.
  EXPECT_CALL(log, Log(logging::LOG_ERROR, _, HasSubstr("error 2"))).Times(0);
  ReturnError(kSequenceNumber, 2);

  // This second error will not be masked, even though 3 was in the original
  // mask, since the error mask was removed by the previous reply.
  EXPECT_CALL(log, Log(logging::LOG_ERROR, _, HasSubstr("error 3"))).Times(1);
  ReturnError(kSequenceNumber, 3);

  StopRTNLHandler();
}
347 
348 }  // namespace shill
349