1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "common/libs/net/netlink_request.h"
18 
19 #include <linux/netlink.h>
20 #include <linux/rtnetlink.h>
21 
22 #include <cstring>
23 #include <ios>
24 #include <ostream>
25 
26 #include <gmock/gmock.h>
27 #include <gtest/gtest.h>
28 
29 using ::testing::ElementsAreArray;
30 using ::testing::MatchResultListener;
31 using ::testing::Return;
32 
33 namespace cuttlefish {
34 namespace {
// No-op replacement for klog_write: every call is accepted and discarded.
// NOTE(review): presumably provided so that code under test which logs via
// klog_write links without the real implementation -- confirm.
extern "C" void klog_write(int, const char*, ...) {
  // Intentionally empty.
}
36 
37 // Dump hex buffer to test log.
Dump(MatchResultListener * result_listener,const char * title,const uint8_t * data,size_t length)38 void Dump(MatchResultListener* result_listener, const char* title,
39           const uint8_t* data, size_t length) {
40   for (size_t item = 0; item < length;) {
41     *result_listener << title;
42     do {
43       result_listener->stream()->width(2);
44       result_listener->stream()->fill('0');
45       *result_listener << std::hex << +data[item] << " ";
46       ++item;
47     } while (item & 0xf);
48     *result_listener << "\n";
49   }
50 }
51 
52 // Compare two memory areas byte by byte, print information about first
53 // difference. Dumps both bufferst to user log.
Compare(MatchResultListener * result_listener,const uint8_t * exp,const uint8_t * act,size_t length)54 bool Compare(MatchResultListener* result_listener,
55              const uint8_t* exp, const uint8_t* act, size_t length) {
56   for (size_t index = 0; index < length; ++index) {
57     if (exp[index] != act[index]) {
58       *result_listener << "\nUnexpected data at offset " << index << "\n";
59       Dump(result_listener, "Data Expected: ", exp, length);
60       Dump(result_listener, "  Data Actual: ", act, length);
61       return false;
62     }
63   }
64 
65   return true;
66 }
67 
68 // Matcher validating Netlink Request data.
69 MATCHER_P2(RequestDataIs, data, length, "Matches expected request data") {
70   size_t offset = sizeof(nlmsghdr);
71   if (offset + length != arg.RequestLength()) {
72     *result_listener << "Unexpected request length: "
73                      << arg.RequestLength() - offset << " vs " << length;
74     return false;
75   }
76 
77   // Note: Request begins with header (nlmsghdr). Header is not covered by this
78   // call.
79   const uint8_t* exp_data = static_cast<const uint8_t*>(
80       static_cast<const void*>(data));
81   const uint8_t* act_data = static_cast<const uint8_t*>(arg.RequestData());
82   return Compare(
83       result_listener, exp_data, &act_data[offset], length);
84 }
85 
86 MATCHER_P4(RequestHeaderIs, length, type, flags, seq,
87            "Matches request header") {
88   nlmsghdr* header = static_cast<nlmsghdr*>(arg.RequestData());
89   if (arg.RequestLength() < sizeof(header)) {
90     *result_listener << "Malformed header: too short.";
91     return false;
92   }
93 
94   if (header->nlmsg_len != length) {
95     *result_listener << "Invalid message length: "
96                      << header->nlmsg_len << " vs " << length;
97     return false;
98   }
99 
100   if (header->nlmsg_type != type) {
101     *result_listener << "Invalid header type: "
102                      << header->nlmsg_type << " vs " << type;
103     return false;
104   }
105 
106   if (header->nlmsg_flags != flags) {
107     *result_listener << "Invalid header flags: "
108                      << header->nlmsg_flags << " vs " << flags;
109     return false;
110   }
111 
112   if (header->nlmsg_seq != seq) {
113     *result_listener << "Invalid header sequence number: "
114                      << header->nlmsg_seq << " vs " << seq;
115     return false;
116   }
117 
118   return true;
119 }
120 }  // namespace
121 
TEST(NetlinkClientTest, BasicStringNode) {
  // A single string attribute: { Dummy: "long string" }.
  constexpr uint16_t kDummyTag = 0xfce2;
  constexpr char kLongString[] = "long string";

  NetlinkRequest req(RTM_SETLINK, 0);
  req.AddString(kDummyTag, kLongString);

  // Expected payload: 4-byte length/type header, then the 11 characters and
  // their terminating NUL (16 bytes in total).
  struct {
    const uint16_t len = 0x10;
    const uint16_t tag = kDummyTag;
    char chars[sizeof(kLongString)];  // sizeof counts the trailing NUL.
  } want;
  std::memcpy(want.chars, kLongString, sizeof(want.chars));

  EXPECT_THAT(req, RequestDataIs(&want, sizeof(want)));
}
139 
TEST(NetlinkClientTest, BasicIntNode) {
  // A single top-level integer attribute: { Dummy: Value }.
  constexpr uint16_t kDummyTag = 0xfce2;
  constexpr int32_t kValue = 0x1badd00d;

  NetlinkRequest req(RTM_SETLINK, 0);
  req.AddInt(kDummyTag, kValue);

  // Expected payload: 4-byte length/type header followed by the 4-byte value.
  struct {
    const uint16_t len = 0x8;
    const uint16_t tag = kDummyTag;
    const uint32_t value = kValue;
  } want;

  EXPECT_THAT(req, RequestDataIs(&want, sizeof(want)));
}
155 
TEST(NetlinkClientTest, AllIntegerTypes) {
  // One attribute per supported integer width and signedness, each under its
  // own tag: int64, int32, int16, int8, uint64, uint32, uint16, uint8.
  constexpr uint16_t kDummyTag = 0xfce2;
  constexpr uint8_t kValue = 0x1b;

  // Packed so the struct matches the request payload byte for byte. The
  // explicit padding fields bring the 6- and 5-byte attributes up to a
  // 4-byte boundary.
  constexpr struct __attribute__((__packed__)) {
    uint16_t attr_length_i64 = 12;
    uint16_t attr_type_i64 = kDummyTag;
    int64_t attr_value_i64 = kValue;
    uint16_t attr_length_i32 = 8;
    uint16_t attr_type_i32 = kDummyTag + 1;
    int32_t attr_value_i32 = kValue;
    uint16_t attr_length_i16 = 6;
    uint16_t attr_type_i16 = kDummyTag + 2;
    int16_t attr_value_i16 = kValue;
    uint8_t attr_padding_i16[2] = {0, 0};
    uint16_t attr_length_i8 = 5;
    uint16_t attr_type_i8 = kDummyTag + 3;
    int8_t attr_value_i8 = kValue;
    uint8_t attr_padding_i8[3] = {0, 0, 0};
    uint16_t attr_length_u64 = 12;
    uint16_t attr_type_u64 = kDummyTag + 4;
    uint64_t attr_value_u64 = kValue;
    uint16_t attr_length_u32 = 8;
    uint16_t attr_type_u32 = kDummyTag + 5;
    uint32_t attr_value_u32 = kValue;
    uint16_t attr_length_u16 = 6;
    uint16_t attr_type_u16 = kDummyTag + 6;
    uint16_t attr_value_u16 = kValue;
    uint8_t attr_padding_u16[2] = {0, 0};
    uint16_t attr_length_u8 = 5;
    uint16_t attr_type_u8 = kDummyTag + 7;
    uint8_t attr_value_u8 = kValue;
    uint8_t attr_padding_u8[3] = {0, 0, 0};
  } expected = {};

  NetlinkRequest request(RTM_SETLINK, 0);
  request.AddInt<int64_t>(kDummyTag, kValue);
  request.AddInt<int32_t>(kDummyTag + 1, kValue);
  request.AddInt<int16_t>(kDummyTag + 2, kValue);
  request.AddInt<int8_t>(kDummyTag + 3, kValue);
  request.AddInt<uint64_t>(kDummyTag + 4, kValue);
  request.AddInt<uint32_t>(kDummyTag + 5, kValue);
  // Unsigned 16/8-bit variants, matching attr_value_u16 / attr_value_u8
  // above (previously these re-tested the signed overloads).
  request.AddInt<uint16_t>(kDummyTag + 6, kValue);
  request.AddInt<uint8_t>(kDummyTag + 7, kValue);

  EXPECT_THAT(request, RequestDataIs(&expected, sizeof(expected)));
}
205 
TEST(NetlinkClientTest, SingleList) {
  // One list holding one integer attribute: List: { Dummy: Value }.
  constexpr uint16_t kDummyTag = 0xfce2;
  constexpr uint16_t kListTag = 0xcafe;
  constexpr int32_t kValue = 0x1badd00d;

  NetlinkRequest req(RTM_SETLINK, 0);
  req.PushList(kListTag);
  req.AddInt(kDummyTag, kValue);
  req.PopList();

  // The list length (12) covers its own 4-byte header plus the nested
  // 8-byte integer attribute.
  struct {
    const uint16_t list_len = 0xc;
    const uint16_t list_tag = kListTag;
    const uint16_t item_len = 0x8;
    const uint16_t item_tag = kDummyTag;
    const uint32_t item_value = kValue;
  } want;

  EXPECT_THAT(req, RequestDataIs(&want, sizeof(want)));
}
227 
TEST(NetlinkClientTest, NestedList) {
  // A list nested inside a list: List1: { List2: { Dummy: Value } }.
  constexpr uint16_t kDummyTag = 0xfce2;
  constexpr uint16_t kList1Tag = 0xcafe;
  constexpr uint16_t kList2Tag = 0xfeed;
  constexpr int32_t kValue = 0x1badd00d;

  NetlinkRequest req(RTM_SETLINK, 0);
  req.PushList(kList1Tag);
  req.PushList(kList2Tag);
  req.AddInt(kDummyTag, kValue);
  req.PopList();  // Closes List2.
  req.PopList();  // Closes List1.

  // Outer list: 4-byte header + 12-byte inner list = 16 bytes.
  // Inner list: 4-byte header + 8-byte attribute = 12 bytes.
  struct {
    const uint16_t outer_len = 0x10;
    const uint16_t outer_tag = kList1Tag;
    const uint16_t inner_len = 0xc;
    const uint16_t inner_tag = kList2Tag;
    const uint16_t item_len = 0x8;
    const uint16_t item_tag = kDummyTag;
    const uint32_t item_value = kValue;
  } want;

  EXPECT_THAT(req, RequestDataIs(&want, sizeof(want)));
}
254 
TEST(NetlinkClientTest, ListSequence) {
  // Two sibling lists, each holding one integer attribute:
  // List1: { Dummy1: Value1 }, List2: { Dummy2: Value2 }
  constexpr uint16_t kDummy1Tag = 0xfce2;
  constexpr uint16_t kDummy2Tag = 0xfd38;
  constexpr uint16_t kList1Tag = 0xcafe;
  constexpr uint16_t kList2Tag = 0xfeed;
  constexpr int32_t kValue1 = 0x1badd00d;
  constexpr int32_t kValue2 = 0xfee1;

  NetlinkRequest req(RTM_SETLINK, 0);

  // Appends a complete single-attribute list to |req|.
  auto add_list = [&req](uint16_t list_tag, uint16_t item_tag, int32_t value) {
    req.PushList(list_tag);
    req.AddInt(item_tag, value);
    req.PopList();
  };
  add_list(kList1Tag, kDummy1Tag, kValue1);
  add_list(kList2Tag, kDummy2Tag, kValue2);

  // Each list is 12 bytes: its own 4-byte header + one 8-byte attribute.
  struct {
    const uint16_t first_list_len = 0xc;
    const uint16_t first_list_tag = kList1Tag;
    const uint16_t first_item_len = 0x8;
    const uint16_t first_item_tag = kDummy1Tag;
    const uint32_t first_item_value = kValue1;
    const uint16_t second_list_len = 0xc;
    const uint16_t second_list_tag = kList2Tag;
    const uint16_t second_item_len = 0x8;
    const uint16_t second_item_tag = kDummy2Tag;
    const uint32_t second_item_value = kValue2;
  } want;

  EXPECT_THAT(req, RequestDataIs(&want, sizeof(want)));
}
287 
TEST(NetlinkClientTest, ComplexList) {
  // A nested list followed by a plain attribute inside the outer list:
  // List1: { List2: { Dummy1: Value1 }, Dummy2: Value2 }
  constexpr uint16_t kDummy1Tag = 0xfce2;
  constexpr uint16_t kDummy2Tag = 0xfd38;
  constexpr uint16_t kList1Tag = 0xcafe;
  constexpr uint16_t kList2Tag = 0xfeed;
  constexpr int32_t kValue1 = 0x1badd00d;
  constexpr int32_t kValue2 = 0xfee1;

  NetlinkRequest req(RTM_SETLINK, 0);
  req.PushList(kList1Tag);
  req.PushList(kList2Tag);
  req.AddInt(kDummy1Tag, kValue1);
  req.PopList();                    // List2 ends after Dummy1.
  req.AddInt(kDummy2Tag, kValue2);  // Still inside List1.
  req.PopList();

  // Outer list: 4 (header) + 12 (inner list) + 8 (Dummy2) = 24 bytes.
  // Inner list: 4 (header) + 8 (Dummy1) = 12 bytes; it stops before Dummy2.
  struct {
    const uint16_t outer_len = 0x18;
    const uint16_t outer_tag = kList1Tag;
    const uint16_t inner_len = 0xc;
    const uint16_t inner_tag = kList2Tag;
    const uint16_t nested_item_len = 0x8;
    const uint16_t nested_item_tag = kDummy1Tag;
    const uint32_t nested_item_value = kValue1;
    const uint16_t trailing_item_len = 0x8;
    const uint16_t trailing_item_tag = kDummy2Tag;
    const uint32_t trailing_item_value = kValue2;
  } want;

  EXPECT_THAT(req, RequestDataIs(&want, sizeof(want)));
}
320 
TEST(NetlinkClientTest, SimpleNetlinkCreateHeader) {
  constexpr char kValue[] = "random string";
  // Header + one attribute header + the string padded via RTA_ALIGN.
  constexpr size_t kMsgLength =
      sizeof(nlmsghdr) + sizeof(nlattr) + RTA_ALIGN(sizeof(kValue));
  // Flags the header is expected to carry for a create request.
  constexpr auto kExpectedFlags =
      NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL | NLM_F_REQUEST;

  NetlinkRequest first(RTM_NEWLINK, NLM_F_CREATE | NLM_F_EXCL);
  first.AddString(0, kValue);  // Payload content is irrelevant here.
  const uint32_t base_seq = first.SeqNo();
  EXPECT_THAT(first, RequestHeaderIs(
      kMsgLength, RTM_NEWLINK, kExpectedFlags, base_seq));

  // The next request is expected to carry the following sequence number.
  NetlinkRequest second(RTM_NEWLINK, NLM_F_CREATE | NLM_F_EXCL);
  second.AddString(0, kValue);
  EXPECT_THAT(second, RequestHeaderIs(
      kMsgLength, RTM_NEWLINK, kExpectedFlags, base_seq + 1));
}
344 
TEST(NetlinkClientTest, SimpleNetlinkUpdateHeader) {
  constexpr char kValue[] = "random string";
  // Header + one attribute header + the string padded via RTA_ALIGN.
  constexpr size_t kMsgLength =
      sizeof(nlmsghdr) + sizeof(nlattr) + RTA_ALIGN(sizeof(kValue));
  // An update request carries only the REQUEST and ACK flags.
  constexpr auto kExpectedFlags = NLM_F_REQUEST | NLM_F_ACK;

  NetlinkRequest first(RTM_SETLINK, 0);
  first.AddString(0, kValue);  // Payload content is irrelevant here.
  const uint32_t base_seq = first.SeqNo();
  EXPECT_THAT(first, RequestHeaderIs(
      kMsgLength, RTM_SETLINK, kExpectedFlags, base_seq));

  // The next request is expected to carry the following sequence number.
  NetlinkRequest second(RTM_SETLINK, 0);
  second.AddString(0, kValue);
  EXPECT_THAT(second, RequestHeaderIs(
      kMsgLength, RTM_SETLINK, kExpectedFlags, base_seq + 1));
}
362 
363 }  // namespace cuttlefish
364