1 //
2 // Copyright (C) 2013 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16
17 #include "shill/icmp.h"
18
19 #include <netinet/ip_icmp.h>
20
21 #include "shill/logging.h"
22 #include "shill/net/ip_address.h"
23 #include "shill/net/sockets.h"
24
25 namespace shill {
26
27 const int Icmp::kIcmpEchoCode = 0; // value specified in RFC 792.
28
Icmp()29 Icmp::Icmp()
30 : sockets_(new Sockets()),
31 socket_(-1) {}
32
~Icmp()33 Icmp::~Icmp() {}
34
Start()35 bool Icmp::Start() {
36 int socket = sockets_->Socket(AF_INET, SOCK_RAW, IPPROTO_ICMP);
37 if (socket == -1) {
38 PLOG(ERROR) << "Could not create ICMP socket";
39 Stop();
40 return false;
41 }
42 socket_ = socket;
43 socket_closer_.reset(new ScopedSocketCloser(sockets_.get(), socket_));
44
45 if (sockets_->SetNonBlocking(socket_) != 0) {
46 PLOG(ERROR) << "Could not set socket to be non-blocking";
47 Stop();
48 return false;
49 }
50
51 return true;
52 }
53
Stop()54 void Icmp::Stop() {
55 socket_closer_.reset();
56 socket_ = -1;
57 }
58
IsStarted() const59 bool Icmp::IsStarted() const {
60 return socket_closer_.get();
61 }
62
TransmitEchoRequest(const IPAddress & destination,uint16_t id,uint16_t seq_num)63 bool Icmp::TransmitEchoRequest(const IPAddress& destination, uint16_t id,
64 uint16_t seq_num) {
65 if (!IsStarted() && !Start()) {
66 return false;
67 }
68
69 if (!destination.IsValid()) {
70 LOG(ERROR) << "Destination address is not valid.";
71 return false;
72 }
73
74 if (destination.family() != IPAddress::kFamilyIPv4) {
75 NOTIMPLEMENTED() << "Only IPv4 destination addresses are implemented.";
76 return false;
77 }
78
79 struct icmphdr icmp_header;
80 memset(&icmp_header, 0, sizeof(icmp_header));
81 icmp_header.type = ICMP_ECHO;
82 icmp_header.code = kIcmpEchoCode;
83 icmp_header.un.echo.id = id;
84 icmp_header.un.echo.sequence = seq_num;
85 icmp_header.checksum = ComputeIcmpChecksum(icmp_header, sizeof(icmp_header));
86
87 struct sockaddr_in destination_address;
88 destination_address.sin_family = AF_INET;
89 CHECK_EQ(sizeof(destination_address.sin_addr.s_addr),
90 destination.GetLength());
91 memcpy(&destination_address.sin_addr.s_addr,
92 destination.address().GetConstData(),
93 sizeof(destination_address.sin_addr.s_addr));
94
95 int result = sockets_->SendTo(
96 socket_,
97 &icmp_header,
98 sizeof(icmp_header),
99 0,
100 reinterpret_cast<struct sockaddr*>(&destination_address),
101 sizeof(destination_address));
102 int expected_result = sizeof(icmp_header);
103 if (result != expected_result) {
104 if (result < 0) {
105 PLOG(ERROR) << "Socket sendto failed";
106 } else if (result < expected_result) {
107 LOG(ERROR) << "Socket sendto returned "
108 << result
109 << " which is less than the expected result "
110 << expected_result;
111 }
112 return false;
113 }
114
115 return true;
116 }
117
118 // static
ComputeIcmpChecksum(const struct icmphdr & hdr,size_t len)119 uint16_t Icmp::ComputeIcmpChecksum(const struct icmphdr& hdr, size_t len) {
120 // Compute Internet Checksum for "len" bytes beginning at location "hdr".
121 // Adapted directly from the canonical implementation in RFC 1071 Section 4.1.
122 uint32_t sum = 0;
123 const uint16_t* addr = reinterpret_cast<const uint16_t*>(&hdr);
124
125 while (len > 1) {
126 sum += *addr;
127 ++addr;
128 len -= sizeof(*addr);
129 }
130
131 // Add left-over byte, if any.
132 if (len > 0) {
133 sum += *reinterpret_cast<const uint8_t*>(addr);
134 }
135
136 // Fold 32-bit sum to 16 bits.
137 while (sum >> 16) {
138 sum = (sum & 0xffff) + (sum >> 16);
139 }
140
141 return static_cast<uint16_t>(~sum);
142 }
143
144 } // namespace shill
145