1 //
2 // Copyright (C) 2015 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16
#include "dhcp_client/dhcpv4.h"

#include <linux/filter.h>
#include <linux/if_packet.h>
#include <net/ethernet.h>
#include <net/if.h>
#include <net/if_arp.h>
#include <netinet/ip.h>
#include <netinet/udp.h>

#include <random>
#include <vector>

#include <base/bind.h>
#include <base/logging.h>

#include "dhcp_client/dhcp_message.h"
33
34 using base::Bind;
35 using base::Unretained;
36 using shill::ByteString;
37 using shill::IOHandlerFactoryContainer;
38
39 namespace dhcp_client {
40
namespace {

// DHCP runs over UDP: servers use port 67 and clients use port 68.
constexpr uint16_t kDHCPServerPort = 67;
constexpr uint16_t kDHCPClientPort = 68;

// Value a descriptor variable holds while no socket is open.
constexpr int kInvalidSocketDescriptor = -1;

// Limits on the IPv4 header length, in octets. RFC 791 gives 20 as the
// smallest valid header and 60 as the largest.
constexpr size_t kIPHeaderMinLength = 20;
constexpr size_t kIPHeaderMaxLength = 60;

// Classic BPF program that only lets DHCP client traffic through.
// Every offset is written as "Ethernet frame offset - ETH_HLEN", i.e. it is
// relative to the first byte of the IP header.
const sock_filter dhcp_bpf_filter[] = {
    // [0] A <- byte 9 of the IP header (the protocol field).
    BPF_STMT(BPF_LD + BPF_B + BPF_ABS, 23 - ETH_HLEN),
    // [1] Not UDP? Skip 6 instructions to the reject at [8].
    BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, IPPROTO_UDP, 0, 6),
    // [2] A <- half-word at byte 6 (flags and fragment offset).
    BPF_STMT(BPF_LD + BPF_H + BPF_ABS, 20 - ETH_HLEN),
    // [3] Any fragment-offset bit set? Skip 4 instructions to the reject.
    BPF_JUMP(BPF_JMP + BPF_JSET + BPF_K, 0x1fff, 4, 0),
    // [4] X <- 4 * (low nibble of byte 0), the IP header length in bytes.
    BPF_STMT(BPF_LDX + BPF_B + BPF_MSH, 14 - ETH_HLEN),
    // [5] A <- half-word at X + 2 (the UDP destination port).
    BPF_STMT(BPF_LD + BPF_H + BPF_IND, 16 - ETH_HLEN),
    // [6] Destination is the DHCP client port? Accept, otherwise reject.
    BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, kDHCPClientPort, 0, 1),
    // [7] Accept: keep up to 0x0fffffff bytes of the packet.
    BPF_STMT(BPF_RET + BPF_K, 0x0fffffff),
    // [8] Reject: keep zero bytes.
    BPF_STMT(BPF_RET + BPF_K, 0),
};

// Number of instructions in dhcp_bpf_filter.
const int dhcp_bpf_filter_len =
    static_cast<int>(sizeof(dhcp_bpf_filter) / sizeof(sock_filter));

}  // namespace
68
DHCPV4(const std::string & interface_name,const ByteString & hardware_address,unsigned int interface_index,const std::string & network_id,bool request_hostname,bool arp_gateway,bool unicast_arp,EventDispatcherInterface * event_dispatcher)69 DHCPV4::DHCPV4(const std::string& interface_name,
70 const ByteString& hardware_address,
71 unsigned int interface_index,
72 const std::string& network_id,
73 bool request_hostname,
74 bool arp_gateway,
75 bool unicast_arp,
76 EventDispatcherInterface* event_dispatcher)
77 : interface_name_(interface_name),
78 hardware_address_(hardware_address),
79 interface_index_(interface_index),
80 network_id_(network_id),
81 request_hostname_(request_hostname),
82 arp_gateway_(arp_gateway),
83 unicast_arp_(unicast_arp),
84 event_dispatcher_(event_dispatcher),
85 io_handler_factory_(
86 IOHandlerFactoryContainer::GetInstance()->GetIOHandlerFactory()),
87 state_(State::INIT),
88 from_(INADDR_ANY),
89 to_(INADDR_BROADCAST),
90 socket_(kInvalidSocketDescriptor),
91 sockets_(new shill::Sockets()),
92 random_engine_(time(nullptr)) {
93 }
94
~DHCPV4()95 DHCPV4::~DHCPV4() {
96 Stop();
97 }
98
ParseRawPacket(shill::InputData * data)99 void DHCPV4::ParseRawPacket(shill::InputData* data) {
100 if (data->len < sizeof(iphdr)) {
101 LOG(ERROR) << "Invalid packet length from buffer";
102 return;
103 }
104 // The socket filter has finished part the header validation.
105 // This function will perform the remaining part.
106 int header_len = ValidatePacketHeader(data->buf, data->len);
107 if (header_len == -1) {
108 return;
109 }
110 unsigned char* buffer = data->buf + header_len;
111 DHCPMessage msg;
112 if (!DHCPMessage::InitFromBuffer(buffer, data->len - header_len, &msg)) {
113 LOG(ERROR) << "Failed to initialize DHCP message from buffer";
114 return;
115 }
116 // In INIT state the client ignores all messages from server.
117 if (state_ == State::INIT) {
118 return;
119 }
120 // Check transaction id with the existing one.
121 if (msg.transaction_id() != transaction_id_) {
122 LOG(ERROR) << "Transaction id(xid) doesn't match";
123 return;
124 }
125 uint8_t message_type = msg.message_type();
126 switch (message_type) {
127 case kDHCPMessageTypeOffer:
128 HandleOffer(msg);
129 break;
130 case kDHCPMessageTypeAck:
131 HandleAck(msg);
132 break;
133 case kDHCPMessageTypeNak:
134 HandleNak(msg);
135 break;
136 default:
137 LOG(ERROR) << "Invalid message type: "
138 << static_cast<int>(message_type);
139 }
140 }
141
OnReadError(const std::string & error_msg)142 void DHCPV4::OnReadError(const std::string& error_msg) {
143 LOG(INFO) << __func__;
144 }
145
Start()146 bool DHCPV4::Start() {
147 if (!CreateRawSocket()) {
148 return false;
149 }
150
151 input_handler_.reset(io_handler_factory_->CreateIOInputHandler(
152 socket_,
153 Bind(&DHCPV4::ParseRawPacket, Unretained(this)),
154 Bind(&DHCPV4::OnReadError, Unretained(this))));
155 return true;
156 }
157
Stop()158 void DHCPV4::Stop() {
159 input_handler_.reset();
160 if (socket_ != kInvalidSocketDescriptor) {
161 sockets_->Close(socket_);
162 }
163 }
164
CreateRawSocket()165 bool DHCPV4::CreateRawSocket() {
166 int fd = sockets_->Socket(PF_PACKET,
167 SOCK_DGRAM | SOCK_CLOEXEC | SOCK_NONBLOCK,
168 htons(ETHERTYPE_IP));
169 if (fd == kInvalidSocketDescriptor) {
170 PLOG(ERROR) << "Failed to create socket";
171 return false;
172 }
173 shill::ScopedSocketCloser socket_closer(sockets_.get(), fd);
174
175 // Apply the socket filter.
176 sock_fprog pf;
177 memset(&pf, 0, sizeof(pf));
178 pf.filter = const_cast<sock_filter*>(dhcp_bpf_filter);
179 pf.len = dhcp_bpf_filter_len;
180
181 if (sockets_->AttachFilter(fd, &pf) != 0) {
182 PLOG(ERROR) << "Failed to attach filter";
183 return false;
184 }
185
186 if (sockets_->ReuseAddress(fd) == -1) {
187 PLOG(ERROR) << "Failed to reuse socket address";
188 return false;
189 }
190
191 if (sockets_->BindToDevice(fd, interface_name_) < 0) {
192 PLOG(ERROR) << "Failed to bind socket to device";
193 return false;
194 }
195
196 struct sockaddr_ll local;
197 memset(&local, 0, sizeof(local));
198 local.sll_family = PF_PACKET;
199 local.sll_protocol = htons(ETHERTYPE_IP);
200 local.sll_ifindex = static_cast<int>(interface_index_);
201
202 if (sockets_->Bind(fd,
203 reinterpret_cast<struct sockaddr*>(&local),
204 sizeof(local)) < 0) {
205 PLOG(ERROR) << "Failed to bind to address";
206 return false;
207 }
208
209 socket_ = socket_closer.Release();
210 return true;
211 }
212
HandleOffer(const DHCPMessage & msg)213 void DHCPV4::HandleOffer(const DHCPMessage& msg) {
214 return;
215 }
216
HandleAck(const DHCPMessage & msg)217 void DHCPV4::HandleAck(const DHCPMessage& msg) {
218 return;
219 }
220
HandleNak(const DHCPMessage & msg)221 void DHCPV4::HandleNak(const DHCPMessage& msg) {
222 return;
223 }
224
MakeRawPacket(const DHCPMessage & message,ByteString * output)225 bool DHCPV4::MakeRawPacket(const DHCPMessage& message, ByteString* output) {
226 ByteString payload;
227 if (!message.Serialize(&payload)) {
228 LOG(ERROR) << "Failed to serialzie dhcp message";
229 return false;
230 }
231 const size_t header_len = sizeof(struct iphdr) + sizeof(struct udphdr);
232 const size_t payload_len = payload.GetLength();
233
234 char buffer[header_len + payload_len];
235 memset(buffer, 0, header_len + payload_len);
236 struct iphdr* ip = reinterpret_cast<struct iphdr*>(buffer);
237 struct udphdr* udp = reinterpret_cast<struct udphdr*>(buffer + sizeof(*ip));
238
239 if (!payload.CopyData(payload_len, buffer + header_len)) {
240 LOG(ERROR) << "Failed to copy data from payload";
241 return false;
242 }
243 udp->uh_sport = htons(kDHCPClientPort);
244 udp->uh_dport = htons(kDHCPServerPort);
245 udp->uh_ulen =
246 htons(static_cast<uint16_t>(sizeof(*udp) + payload.GetLength()));
247
248 // Fill pseudo header (for UDP checksum computing):
249 // Protocol.
250 ip->protocol = IPPROTO_UDP;
251 // Source IP address.
252 ip->saddr = htonl(from_);
253 // Destination IP address.
254 ip->daddr = htonl(to_);
255 // Total length, use udp packet length for pseudo header.
256 ip->tot_len = udp->uh_ulen;
257 // Calculate udp checksum based on:
258 // IPV4 pseudo header, UDP header, and payload.
259 udp->uh_sum = htons(DHCPMessage::ComputeChecksum(
260 reinterpret_cast<const uint8_t*>(buffer),
261 header_len + payload_len));
262
263 // IP version.
264 ip->version = IPVERSION;
265 // IP header length.
266 ip->ihl = sizeof(*ip) >> 2;
267 // Fragment offset field.
268 // The DHCP packet is always smaller than MTU,
269 // so fragmentation is not needed.
270 ip->frag_off = 0;
271 // Identification.
272 ip->id = static_cast<uint16_t>(
273 std::uniform_int_distribution<unsigned int>()(
274 random_engine_) % UINT16_MAX + 1);
275 // Time to live.
276 ip->ttl = IPDEFTTL;
277 // Total length.
278 ip->tot_len = htons(static_cast<uint16_t>(header_len+ payload.GetLength()));
279 // Calculate IP Checksum only based on IP header.
280 ip->check = htons(DHCPMessage::ComputeChecksum(
281 reinterpret_cast<const uint8_t*>(ip),
282 sizeof(*ip)));
283
284 *output = ByteString(buffer, header_len + payload_len);
285 return true;
286 }
287
SendRawPacket(const ByteString & packet)288 bool DHCPV4::SendRawPacket(const ByteString& packet) {
289 struct sockaddr_ll remote;
290 memset(&remote, 0, sizeof(remote));
291 remote.sll_family = AF_PACKET;
292 remote.sll_protocol = htons(ETHERTYPE_IP);
293 remote.sll_ifindex = interface_index_;
294 remote.sll_hatype = htons(ARPHRD_ETHER);
295 // Use broadcast hardware address.
296 remote.sll_halen = IFHWADDRLEN;
297 memset(remote.sll_addr, 0xff, IFHWADDRLEN);
298
299 size_t result = sockets_->SendTo(socket_,
300 packet.GetConstData(),
301 packet.GetLength(),
302 0,
303 reinterpret_cast<struct sockaddr *>(&remote),
304 sizeof(remote));
305
306 if (result != packet.GetLength()) {
307 PLOG(ERROR) << "Socket sento failed";
308 return false;
309 }
310 return true;
311 }
312
ValidatePacketHeader(const unsigned char * buffer,size_t len)313 int DHCPV4::ValidatePacketHeader(const unsigned char* buffer, size_t len) {
314 const struct iphdr* ip =
315 reinterpret_cast<const struct iphdr*>(buffer);
316 const size_t ip_header_len = static_cast<size_t>(ip->ihl) << 2;
317 if (ip_header_len < kIPHeaderMinLength ||
318 ip_header_len > kIPHeaderMaxLength) {
319 LOG(ERROR) << "Invalid Internet Header Length: "
320 << ip_header_len << " bytes";
321 return -1;
322 }
323 if (ip->tot_len != len) {
324 LOG(ERROR) << "Invalid IP total length";
325 return -1;
326 }
327 // TODO(nywang): Validate other ip header fields.
328
329 const struct udphdr* udp =
330 reinterpret_cast<const struct udphdr*>(buffer + ip_header_len);
331 if (udp->uh_sport != htons(kDHCPServerPort) ||
332 udp->uh_dport != htons(kDHCPClientPort)) {
333 LOG(ERROR) << "Invlaid UDP ports";
334 return -1;
335 }
336 if (udp->uh_ulen != len - ip_header_len) {
337 LOG(ERROR) << "Invalid UDP total length";
338 return -1;
339 }
340 // TODO(nywang): Validate UDP checksum.
341
342 return ip_header_len + sizeof(*udp);
343 }
344
345 } // namespace dhcp_client
346
347