1 //
2 // Copyright (C) 2013 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #include "shill/icmp.h"
18 
#include <netinet/ip_icmp.h>
#include <string.h>
20 
21 #include "shill/logging.h"
22 #include "shill/net/ip_address.h"
23 #include "shill/net/sockets.h"
24 
25 namespace shill {
26 
27 const int Icmp::kIcmpEchoCode = 0;  // value specified in RFC 792.
28 
Icmp()29 Icmp::Icmp()
30     : sockets_(new Sockets()),
31       socket_(-1) {}
32 
~Icmp()33 Icmp::~Icmp() {}
34 
Start()35 bool Icmp::Start() {
36   int socket = sockets_->Socket(AF_INET, SOCK_RAW, IPPROTO_ICMP);
37   if (socket == -1) {
38     PLOG(ERROR) << "Could not create ICMP socket";
39     Stop();
40     return false;
41   }
42   socket_ = socket;
43   socket_closer_.reset(new ScopedSocketCloser(sockets_.get(), socket_));
44 
45   if (sockets_->SetNonBlocking(socket_) != 0) {
46     PLOG(ERROR) << "Could not set socket to be non-blocking";
47     Stop();
48     return false;
49   }
50 
51   return true;
52 }
53 
Stop()54 void Icmp::Stop() {
55   socket_closer_.reset();
56   socket_ = -1;
57 }
58 
IsStarted() const59 bool Icmp::IsStarted() const {
60   return socket_closer_.get();
61 }
62 
TransmitEchoRequest(const IPAddress & destination,uint16_t id,uint16_t seq_num)63 bool Icmp::TransmitEchoRequest(const IPAddress& destination, uint16_t id,
64                                uint16_t seq_num) {
65   if (!IsStarted() && !Start()) {
66     return false;
67   }
68 
69   if (!destination.IsValid()) {
70     LOG(ERROR) << "Destination address is not valid.";
71     return false;
72   }
73 
74   if (destination.family() != IPAddress::kFamilyIPv4) {
75     NOTIMPLEMENTED() << "Only IPv4 destination addresses are implemented.";
76     return false;
77   }
78 
79   struct icmphdr icmp_header;
80   memset(&icmp_header, 0, sizeof(icmp_header));
81   icmp_header.type = ICMP_ECHO;
82   icmp_header.code = kIcmpEchoCode;
83   icmp_header.un.echo.id = id;
84   icmp_header.un.echo.sequence = seq_num;
85   icmp_header.checksum = ComputeIcmpChecksum(icmp_header, sizeof(icmp_header));
86 
87   struct sockaddr_in destination_address;
88   destination_address.sin_family = AF_INET;
89   CHECK_EQ(sizeof(destination_address.sin_addr.s_addr),
90            destination.GetLength());
91   memcpy(&destination_address.sin_addr.s_addr,
92          destination.address().GetConstData(),
93          sizeof(destination_address.sin_addr.s_addr));
94 
95   int result = sockets_->SendTo(
96       socket_,
97       &icmp_header,
98       sizeof(icmp_header),
99       0,
100       reinterpret_cast<struct sockaddr*>(&destination_address),
101       sizeof(destination_address));
102   int expected_result = sizeof(icmp_header);
103   if (result != expected_result) {
104     if (result < 0) {
105       PLOG(ERROR) << "Socket sendto failed";
106     } else if (result < expected_result) {
107       LOG(ERROR) << "Socket sendto returned "
108                  << result
109                  << " which is less than the expected result "
110                  << expected_result;
111     }
112     return false;
113   }
114 
115   return true;
116 }
117 
118 // static
ComputeIcmpChecksum(const struct icmphdr & hdr,size_t len)119 uint16_t Icmp::ComputeIcmpChecksum(const struct icmphdr& hdr, size_t len) {
120   // Compute Internet Checksum for "len" bytes beginning at location "hdr".
121   // Adapted directly from the canonical implementation in RFC 1071 Section 4.1.
122   uint32_t sum = 0;
123   const uint16_t* addr = reinterpret_cast<const uint16_t*>(&hdr);
124 
125   while (len > 1) {
126     sum += *addr;
127     ++addr;
128     len -= sizeof(*addr);
129   }
130 
131   // Add left-over byte, if any.
132   if (len > 0) {
133     sum += *reinterpret_cast<const uint8_t*>(addr);
134   }
135 
136   // Fold 32-bit sum to 16 bits.
137   while (sum >> 16) {
138     sum = (sum & 0xffff) + (sum >> 16);
139   }
140 
141   return static_cast<uint16_t>(~sum);
142 }
143 
144 }  // namespace shill
145