1 //
2 // Copyright (C) 2013 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16
17 #include "shill/net/netlink_socket.h"
18
19 #include <linux/netlink.h>
20
21 #include <algorithm>
22 #include <string>
23
24 #include <gmock/gmock.h>
25 #include <gtest/gtest.h>
26
27 #include "shill/net/byte_string.h"
28 #include "shill/net/mock_sockets.h"
29 #include "shill/net/netlink_message.h"
30
31 using std::min;
32 using std::string;
33 using testing::_;
34 using testing::Invoke;
35 using testing::Return;
36 using testing::Test;
37
38 namespace shill {
39
// Fake file descriptor that the mocked Sockets::Socket() call returns to
// NetlinkSocket::Init() in these tests.
const int kFakeFd = 99;
43
44 class NetlinkSocketTest : public Test {
45 public:
NetlinkSocketTest()46 NetlinkSocketTest() {}
~NetlinkSocketTest()47 virtual ~NetlinkSocketTest() {}
48
SetUp()49 virtual void SetUp() {
50 mock_sockets_ = new MockSockets();
51 netlink_socket_.sockets_.reset(mock_sockets_);
52 }
53
InitializeSocket(int fd)54 virtual void InitializeSocket(int fd) {
55 EXPECT_CALL(*mock_sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_GENERIC))
56 .WillOnce(Return(fd));
57 EXPECT_CALL(*mock_sockets_, SetReceiveBuffer(
58 fd, NetlinkSocket::kReceiveBufferSize)).WillOnce(Return(0));
59 EXPECT_CALL(*mock_sockets_, Bind(fd, _, sizeof(struct sockaddr_nl)))
60 .WillOnce(Return(0));
61 EXPECT_TRUE(netlink_socket_.Init());
62 }
63
64 protected:
65 MockSockets* mock_sockets_; // Owned by netlink_socket_.
66 NetlinkSocket netlink_socket_;
67 };
68
69 class FakeSocketRead {
70 public:
FakeSocketRead(const ByteString & next_read_string)71 explicit FakeSocketRead(const ByteString& next_read_string) {
72 next_read_string_ = next_read_string;
73 }
74 // Copies |len| bytes of |next_read_string_| into |buf| and clears
75 // |next_read_string_|.
FakeSuccessfulRead(int sockfd,void * buf,size_t len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)76 ssize_t FakeSuccessfulRead(int sockfd, void* buf, size_t len, int flags,
77 struct sockaddr* src_addr, socklen_t* addrlen) {
78 if (!buf) {
79 return -1;
80 }
81 int read_bytes = min(len, next_read_string_.GetLength());
82 memcpy(buf, next_read_string_.GetConstData(), read_bytes);
83 next_read_string_.Clear();
84 return read_bytes;
85 }
86
87 private:
88 ByteString next_read_string_;
89 };
90
// Init() succeeds when every setup call succeeds. The fixture's SetUp() has
// already been run by gtest, so it is not called again here.
TEST_F(NetlinkSocketTest, InitWorkingTest) {
  InitializeSocket(kFakeFd);

  // Destructor.
  EXPECT_CALL(*mock_sockets_, Close(kFakeFd));
}
96
// When Socket() fails, Init() fails without calling SetReceiveBuffer() or
// Bind().
TEST_F(NetlinkSocketTest, InitBrokenSocketTest) {
  const int kBadFd = -1;
  EXPECT_CALL(*mock_sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_GENERIC))
      .WillOnce(Return(kBadFd));
  EXPECT_CALL(*mock_sockets_, SetReceiveBuffer(_, _)).Times(0);
  EXPECT_CALL(*mock_sockets_, Bind(_, _, _)).Times(0);
  EXPECT_FALSE(netlink_socket_.Init());
}
107
// A SetReceiveBuffer() failure is not fatal: Init() still binds and
// succeeds.
TEST_F(NetlinkSocketTest, InitBrokenBufferTest) {
  EXPECT_CALL(*mock_sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_GENERIC))
      .WillOnce(Return(kFakeFd));
  EXPECT_CALL(*mock_sockets_, SetReceiveBuffer(
      kFakeFd, NetlinkSocket::kReceiveBufferSize)).WillOnce(Return(-1));
  EXPECT_CALL(*mock_sockets_, Bind(kFakeFd, _, sizeof(struct sockaddr_nl)))
      .WillOnce(Return(0));
  EXPECT_TRUE(netlink_socket_.Init());

  // Destructor.
  EXPECT_CALL(*mock_sockets_, Close(kFakeFd));
}
122
// A Bind() failure makes Init() fail, and the socket is closed.
TEST_F(NetlinkSocketTest, InitBrokenBindTest) {
  EXPECT_CALL(*mock_sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_GENERIC))
      .WillOnce(Return(kFakeFd));
  EXPECT_CALL(*mock_sockets_, SetReceiveBuffer(
      kFakeFd, NetlinkSocket::kReceiveBufferSize)).WillOnce(Return(0));
  EXPECT_CALL(*mock_sockets_, Bind(kFakeFd, _, sizeof(struct sockaddr_nl)))
      .WillOnce(Return(-1));
  EXPECT_CALL(*mock_sockets_, Close(kFakeFd)).WillOnce(Return(0));
  EXPECT_FALSE(netlink_socket_.Init());
}
135
// SendMessage() succeeds only when Send() reports that the whole message was
// written. Short writes and errors both make it return false.
TEST_F(NetlinkSocketTest, SendMessageTest) {
  InitializeSocket(kFakeFd);

  const string message_string("This text is really arbitrary");
  ByteString message(message_string.c_str(), message_string.size());

  // Good Send.
  EXPECT_CALL(*mock_sockets_,
              Send(kFakeFd, message.GetConstData(), message.GetLength(), 0))
      .WillOnce(Return(message.GetLength()));
  EXPECT_TRUE(netlink_socket_.SendMessage(message));

  // Short Send.
  EXPECT_CALL(*mock_sockets_,
              Send(kFakeFd, message.GetConstData(), message.GetLength(), 0))
      .WillOnce(Return(message.GetLength() - 3));
  EXPECT_FALSE(netlink_socket_.SendMessage(message));

  // Bad Send.
  EXPECT_CALL(*mock_sockets_,
              Send(kFakeFd, message.GetConstData(), message.GetLength(), 0))
      .WillOnce(Return(-1));
  EXPECT_FALSE(netlink_socket_.SendMessage(message));

  // Destructor.
  EXPECT_CALL(*mock_sockets_, Close(kFakeFd));
}
164
// GetSequenceNumber() advances |sequence_number_| and never hands out
// NetlinkMessage::kBroadcastSequenceNumber.
TEST_F(NetlinkSocketTest, SequenceNumberTest) {
  // An ordinary value advances by one.
  const uint32_t arbitrary_number = 42;
  netlink_socket_.sequence_number_ = arbitrary_number;
  EXPECT_EQ(arbitrary_number + 1, netlink_socket_.GetSequenceNumber());

  // Make sure we don't go to |NetlinkMessage::kBroadcastSequenceNumber|.
  netlink_socket_.sequence_number_ = NetlinkMessage::kBroadcastSequenceNumber;
  EXPECT_NE(NetlinkMessage::kBroadcastSequenceNumber,
            netlink_socket_.GetSequenceNumber());
}
178
// RecvMessage() is expected to peek (MSG_TRUNC | MSG_PEEK) for the pending
// size, then read exactly that many bytes into |message|.
TEST_F(NetlinkSocketTest, GoodRecvMessageTest) {
  InitializeSocket(kFakeFd);

  ByteString message;
  // Plain locals: a function-local static adds a guarded one-time
  // initialization for no benefit in a single test.
  const string next_read_string(
      "Random text may include things like 'freaking fracking foo'.");
  const size_t read_size = next_read_string.size();
  ByteString expected_results(next_read_string.c_str(), read_size);
  FakeSocketRead fake_socket_read(expected_results);

  // Expect one call to get the size...
  EXPECT_CALL(*mock_sockets_,
              RecvFrom(kFakeFd, _, _, MSG_TRUNC | MSG_PEEK, _, _))
      .WillOnce(Return(read_size));

  // ...and expect a second call to get the data.
  EXPECT_CALL(*mock_sockets_,
              RecvFrom(kFakeFd, _, read_size, 0, _, _))
      .WillOnce(Invoke(&fake_socket_read, &FakeSocketRead::FakeSuccessfulRead));

  EXPECT_TRUE(netlink_socket_.RecvMessage(&message));
  EXPECT_TRUE(message.Equals(expected_results));

  // Destructor.
  EXPECT_CALL(*mock_sockets_, Close(kFakeFd));
}
206
// RecvMessage() fails when RecvFrom() returns an error.
TEST_F(NetlinkSocketTest, BadRecvMessageTest) {
  InitializeSocket(kFakeFd);

  ByteString message;
  EXPECT_CALL(*mock_sockets_, RecvFrom(kFakeFd, _, _, _, _, _))
      .WillOnce(Return(-1));
  EXPECT_FALSE(netlink_socket_.RecvMessage(&message));

  // Destructor.
  EXPECT_CALL(*mock_sockets_, Close(kFakeFd));
}
218
219 } // namespace shill.
220