1 //
2 // Copyright (C) 2013 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #include "shill/net/netlink_socket.h"
18 
#include <linux/netlink.h>

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <string>
23 
24 #include <gmock/gmock.h>
25 #include <gtest/gtest.h>
26 
27 #include "shill/net/byte_string.h"
28 #include "shill/net/mock_sockets.h"
29 #include "shill/net/netlink_message.h"
30 
31 using std::min;
32 using std::string;
33 using testing::_;
34 using testing::Invoke;
35 using testing::Return;
36 using testing::Test;
37 
38 namespace shill {
39 
// Arbitrary file descriptor value returned by the mocked Socket() call in the
// tests below.
const int kFakeFd = 99;
43 
44 class NetlinkSocketTest : public Test {
45  public:
NetlinkSocketTest()46   NetlinkSocketTest() {}
~NetlinkSocketTest()47   virtual ~NetlinkSocketTest() {}
48 
SetUp()49   virtual void SetUp() {
50     mock_sockets_ = new MockSockets();
51     netlink_socket_.sockets_.reset(mock_sockets_);
52   }
53 
InitializeSocket(int fd)54   virtual void InitializeSocket(int fd) {
55     EXPECT_CALL(*mock_sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_GENERIC))
56         .WillOnce(Return(fd));
57     EXPECT_CALL(*mock_sockets_, SetReceiveBuffer(
58         fd, NetlinkSocket::kReceiveBufferSize)).WillOnce(Return(0));
59     EXPECT_CALL(*mock_sockets_, Bind(fd, _, sizeof(struct sockaddr_nl)))
60         .WillOnce(Return(0));
61     EXPECT_TRUE(netlink_socket_.Init());
62   }
63 
64  protected:
65   MockSockets* mock_sockets_;  // Owned by netlink_socket_.
66   NetlinkSocket netlink_socket_;
67 };
68 
69 class FakeSocketRead {
70  public:
FakeSocketRead(const ByteString & next_read_string)71   explicit FakeSocketRead(const ByteString& next_read_string) {
72     next_read_string_ = next_read_string;
73   }
74   // Copies |len| bytes of |next_read_string_| into |buf| and clears
75   // |next_read_string_|.
FakeSuccessfulRead(int sockfd,void * buf,size_t len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)76   ssize_t FakeSuccessfulRead(int sockfd, void* buf, size_t len, int flags,
77                              struct sockaddr* src_addr, socklen_t* addrlen) {
78     if (!buf) {
79       return -1;
80     }
81     int read_bytes = min(len, next_read_string_.GetLength());
82     memcpy(buf, next_read_string_.GetConstData(), read_bytes);
83     next_read_string_.Clear();
84     return read_bytes;
85   }
86 
87  private:
88   ByteString next_read_string_;
89 };
90 
// Happy path: Init() succeeds, and the descriptor is closed when
// |netlink_socket_| is destroyed. gtest has already run SetUp(), so it is not
// called again here.
TEST_F(NetlinkSocketTest, InitWorkingTest) {
  InitializeSocket(kFakeFd);

  // Destructor.
  EXPECT_CALL(*mock_sockets_, Close(kFakeFd));
}
96 
// When Socket() fails, Init() must fail without attempting to size the
// receive buffer or bind.
TEST_F(NetlinkSocketTest, InitBrokenSocketTest) {
  const int kBadFd = -1;
  EXPECT_CALL(*mock_sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_GENERIC))
      .WillOnce(Return(kBadFd));
  EXPECT_CALL(*mock_sockets_, SetReceiveBuffer(_, _)).Times(0);
  EXPECT_CALL(*mock_sockets_, Bind(_, _, _)).Times(0);
  EXPECT_FALSE(netlink_socket_.Init());
}
107 
// A SetReceiveBuffer() failure is not fatal: Init() still binds and succeeds.
TEST_F(NetlinkSocketTest, InitBrokenBufferTest) {
  EXPECT_CALL(*mock_sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_GENERIC))
      .WillOnce(Return(kFakeFd));
  EXPECT_CALL(*mock_sockets_, SetReceiveBuffer(
      kFakeFd, NetlinkSocket::kReceiveBufferSize)).WillOnce(Return(-1));
  EXPECT_CALL(*mock_sockets_, Bind(kFakeFd, _, sizeof(struct sockaddr_nl)))
      .WillOnce(Return(0));
  EXPECT_TRUE(netlink_socket_.Init());

  // Destructor.
  EXPECT_CALL(*mock_sockets_, Close(kFakeFd));
}
122 
// A Bind() failure makes Init() fail, and the descriptor is closed exactly
// once.
TEST_F(NetlinkSocketTest, InitBrokenBindTest) {
  EXPECT_CALL(*mock_sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_GENERIC))
      .WillOnce(Return(kFakeFd));
  EXPECT_CALL(*mock_sockets_, SetReceiveBuffer(
      kFakeFd, NetlinkSocket::kReceiveBufferSize)).WillOnce(Return(0));
  EXPECT_CALL(*mock_sockets_, Bind(kFakeFd, _, sizeof(struct sockaddr_nl)))
      .WillOnce(Return(-1));
  EXPECT_CALL(*mock_sockets_, Close(kFakeFd)).WillOnce(Return(0));
  EXPECT_FALSE(netlink_socket_.Init());
}
135 
// SendMessage() succeeds only when Send() reports that the whole message was
// written; a short write or an error both yield false.
TEST_F(NetlinkSocketTest, SendMessageTest) {
  InitializeSocket(kFakeFd);

  string message_string("This text is really arbitrary");
  ByteString message(message_string.c_str(), message_string.size());

  // Good Send.
  EXPECT_CALL(*mock_sockets_,
              Send(kFakeFd, message.GetConstData(), message.GetLength(), 0))
      .WillOnce(Return(message.GetLength()));
  EXPECT_TRUE(netlink_socket_.SendMessage(message));

  // Short Send.
  EXPECT_CALL(*mock_sockets_,
              Send(kFakeFd, message.GetConstData(), message.GetLength(), 0))
      .WillOnce(Return(message.GetLength() - 3));
  EXPECT_FALSE(netlink_socket_.SendMessage(message));

  // Bad Send.
  EXPECT_CALL(*mock_sockets_,
              Send(kFakeFd, message.GetConstData(), message.GetLength(), 0))
      .WillOnce(Return(-1));
  EXPECT_FALSE(netlink_socket_.SendMessage(message));

  // Destructor.
  EXPECT_CALL(*mock_sockets_, Close(kFakeFd));
}
164 
// GetSequenceNumber() pre-increments the stored counter and never hands out
// NetlinkMessage::kBroadcastSequenceNumber. No socket is initialized here.
TEST_F(NetlinkSocketTest, SequenceNumberTest) {
  // Just a sequence number.
  const uint32_t arbitrary_number = 42;
  netlink_socket_.sequence_number_ = arbitrary_number;
  EXPECT_EQ(arbitrary_number + 1, netlink_socket_.GetSequenceNumber());

  // Make sure we don't go to |NetlinkMessage::kBroadcastSequenceNumber|.
  netlink_socket_.sequence_number_ = NetlinkMessage::kBroadcastSequenceNumber;
  EXPECT_NE(NetlinkMessage::kBroadcastSequenceNumber,
            netlink_socket_.GetSequenceNumber());
}
178 
// RecvMessage() first peeks (MSG_TRUNC | MSG_PEEK) to learn the pending
// message size, then reads exactly that many bytes into |message|.
TEST_F(NetlinkSocketTest, GoodRecvMessageTest) {
  InitializeSocket(kFakeFd);

  ByteString message;
  // Plain locals: a function-local static would add a guarded one-time
  // initialization and outlive the test for no benefit.
  const string next_read_string(
      "Random text may include things like 'freaking fracking foo'.");
  const size_t read_size = next_read_string.size();
  ByteString expected_results(next_read_string.c_str(), read_size);
  FakeSocketRead fake_socket_read(expected_results);

  // Expect one call to get the size...
  EXPECT_CALL(*mock_sockets_,
              RecvFrom(kFakeFd, _, _, MSG_TRUNC | MSG_PEEK, _, _))
      .WillOnce(Return(read_size));

  // ...and expect a second call to get the data.
  EXPECT_CALL(*mock_sockets_,
              RecvFrom(kFakeFd, _, read_size, 0, _, _))
      .WillOnce(Invoke(&fake_socket_read, &FakeSocketRead::FakeSuccessfulRead));

  EXPECT_TRUE(netlink_socket_.RecvMessage(&message));
  EXPECT_TRUE(message.Equals(expected_results));

  // Destructor.
  EXPECT_CALL(*mock_sockets_, Close(kFakeFd));
}
206 
// RecvMessage() returns false when RecvFrom() fails.
TEST_F(NetlinkSocketTest, BadRecvMessageTest) {
  InitializeSocket(kFakeFd);

  ByteString message;
  EXPECT_CALL(*mock_sockets_, RecvFrom(kFakeFd, _, _, _, _, _))
      .WillOnce(Return(-1));
  EXPECT_FALSE(netlink_socket_.RecvMessage(&message));

  // Destructor.
  EXPECT_CALL(*mock_sockets_, Close(kFakeFd));
}
218 
219 }  // namespace shill.
220