1 //
2 // Copyright (C) 2012 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #include "shill/arp_client.h"
18 
19 #include <linux/if_packet.h>
20 #include <net/ethernet.h>
21 #include <net/if_arp.h>
22 #include <netinet/in.h>
23 #include <string.h>
24 
25 #include "shill/arp_packet.h"
26 #include "shill/logging.h"
27 #include "shill/net/byte_string.h"
28 #include "shill/net/sockets.h"
29 
30 namespace shill {
31 
32 // ARP opcode is the last uint16_t in the ARP header.
33 const size_t ArpClient::kArpOpOffset = sizeof(arphdr) - sizeof(uint16_t);
34 
35 // The largest packet we expect is one with IPv6 addresses in it.
36 const size_t ArpClient::kMaxArpPacketLength =
37     sizeof(arphdr) + sizeof(in6_addr) * 2 + ETH_ALEN * 2;
38 
ArpClient(int interface_index)39 ArpClient::ArpClient(int interface_index)
40     : interface_index_(interface_index),
41       sockets_(new Sockets()),
42       socket_(-1) {}
43 
~ArpClient()44 ArpClient::~ArpClient() {}
45 
StartReplyListener()46 bool ArpClient::StartReplyListener() {
47   return Start(ARPOP_REPLY);
48 }
49 
StartRequestListener()50 bool ArpClient::StartRequestListener() {
51   return Start(ARPOP_REQUEST);
52 }
53 
Start(uint16_t arp_opcode)54 bool ArpClient::Start(uint16_t arp_opcode) {
55   if (!CreateSocket(arp_opcode)) {
56     LOG(ERROR) << "Could not open ARP socket.";
57     Stop();
58     return false;
59   }
60   return true;
61 }
62 
Stop()63 void ArpClient::Stop() {
64   socket_closer_.reset();
65 }
66 
67 
CreateSocket(uint16_t arp_opcode)68 bool ArpClient::CreateSocket(uint16_t arp_opcode) {
69   int socket = sockets_->Socket(PF_PACKET, SOCK_DGRAM, htons(ETHERTYPE_ARP));
70   if (socket == -1) {
71     PLOG(ERROR) << "Could not create ARP socket";
72     return false;
73   }
74   socket_ = socket;
75   socket_closer_.reset(new ScopedSocketCloser(sockets_.get(), socket_));
76 
77   // Create a packet filter incoming ARP packets.
78   const sock_filter arp_filter[] = {
79     // If a packet contains the ARP opcode we are looking for...
80     BPF_STMT(BPF_LD | BPF_H | BPF_ABS, kArpOpOffset),
81     BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, arp_opcode, 0, 1),
82     // Return the the packet (up to largest expected packet size).
83     BPF_STMT(BPF_RET | BPF_K, kMaxArpPacketLength),
84     // Otherwise, drop it.
85     BPF_STMT(BPF_RET | BPF_K, 0),
86   };
87 
88   sock_fprog pf;
89   pf.filter = const_cast<sock_filter*>(arp_filter);
90   pf.len = arraysize(arp_filter);
91   if (sockets_->AttachFilter(socket_, &pf) != 0) {
92     PLOG(ERROR) << "Could not attach packet filter";
93     return false;
94   }
95 
96   if (sockets_->SetNonBlocking(socket_) != 0) {
97     PLOG(ERROR) << "Could not set socket to be non-blocking";
98     return false;
99   }
100 
101   sockaddr_ll socket_address;
102   memset(&socket_address, 0, sizeof(socket_address));
103   socket_address.sll_family = AF_PACKET;
104   socket_address.sll_protocol = htons(ETHERTYPE_ARP);
105   socket_address.sll_ifindex = interface_index_;
106 
107   if (sockets_->Bind(socket_,
108                      reinterpret_cast<struct sockaddr*>(&socket_address),
109                      sizeof(socket_address)) != 0) {
110     PLOG(ERROR) << "Could not bind socket to interface";
111     return false;
112   }
113 
114   return true;
115 }
116 
ReceivePacket(ArpPacket * packet,ByteString * sender) const117 bool ArpClient::ReceivePacket(ArpPacket* packet, ByteString* sender) const {
118   ByteString payload(kMaxArpPacketLength);
119   sockaddr_ll socket_address;
120   memset(&socket_address, 0, sizeof(socket_address));
121   socklen_t socklen = sizeof(socket_address);
122   int result = sockets_->RecvFrom(
123       socket_,
124       payload.GetData(),
125       payload.GetLength(),
126       0,
127       reinterpret_cast<struct sockaddr*>(&socket_address),
128       &socklen);
129   if (result < 0) {
130     PLOG(ERROR) << "Socket recvfrom failed";
131     return false;
132   }
133 
134   payload.Resize(result);
135   if (!packet->Parse(payload)) {
136     LOG(ERROR) << "Failed to parse ARP packet.";
137     return false;
138   }
139 
140   // The socket address returned may only be big enough to contain
141   // the hardware address of the sender.
142   CHECK(socklen >=
143         sizeof(socket_address) - sizeof(socket_address.sll_addr) + ETH_ALEN);
144   CHECK(socket_address.sll_halen == ETH_ALEN);
145   *sender = ByteString(
146       reinterpret_cast<const unsigned char*>(&socket_address.sll_addr),
147       socket_address.sll_halen);
148   return true;
149 }
150 
TransmitRequest(const ArpPacket & packet) const151 bool ArpClient::TransmitRequest(const ArpPacket& packet) const {
152   ByteString payload;
153   if (!packet.FormatRequest(&payload)) {
154     return false;
155   }
156 
157   sockaddr_ll socket_address;
158   memset(&socket_address, 0, sizeof(socket_address));
159   socket_address.sll_family = AF_PACKET;
160   socket_address.sll_protocol = htons(ETHERTYPE_ARP);
161   socket_address.sll_hatype = ARPHRD_ETHER;
162   socket_address.sll_halen = ETH_ALEN;
163   socket_address.sll_ifindex = interface_index_;
164 
165   ByteString remote_address = packet.remote_mac_address();
166   CHECK(sizeof(socket_address.sll_addr) >= remote_address.GetLength());
167   if (remote_address.IsZero()) {
168     // If the destination MAC address is unspecified, send the packet
169     // to the broadcast (all-ones) address.
170     remote_address.BitwiseInvert();
171   }
172   memcpy(&socket_address.sll_addr, remote_address.GetConstData(),
173          remote_address.GetLength());
174 
175   int result = sockets_->SendTo(
176       socket_,
177       payload.GetConstData(),
178       payload.GetLength(),
179       0,
180       reinterpret_cast<struct sockaddr*>(&socket_address),
181       sizeof(socket_address));
182   const int expected_result  = static_cast<int>(payload.GetLength());
183   if (result != expected_result) {
184     if (result < 0) {
185       PLOG(ERROR) << "Socket sendto failed";
186     } else if (result < static_cast<int>(payload.GetLength())) {
187       LOG(ERROR) << "Socket sendto returned "
188                  << result
189                  << " which is different from expected result "
190                  << expected_result;
191     }
192     return false;
193   }
194 
195   return true;
196 }
197 
198 }  // namespace shill
199