1 //
2 // Copyright (C) 2013 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #include "shill/connection_health_checker.h"
18 
19 #include <arpa/inet.h>
20 #include <netinet/in.h>
21 #include <stdlib.h>
22 #include <sys/socket.h>
23 #include <sys/types.h>
24 #include <time.h>
25 
26 #include <vector>
27 
28 #include <base/bind.h>
29 
30 #include "shill/async_connection.h"
31 #include "shill/connection.h"
32 #include "shill/dns_client.h"
33 #include "shill/dns_client_factory.h"
34 #include "shill/error.h"
35 #include "shill/http_url.h"
36 #include "shill/ip_address_store.h"
37 #include "shill/logging.h"
38 #include "shill/net/ip_address.h"
39 #include "shill/net/sockets.h"
40 #include "shill/socket_info.h"
41 #include "shill/socket_info_reader.h"
42 
43 using base::Bind;
44 using base::Unretained;
45 using std::string;
46 using std::vector;
47 
48 namespace shill {
49 
50 namespace Logging {
51 static auto kModuleLogScope = ScopeLogger::kConnection;
ObjectID(Connection * c)52 static string ObjectID(Connection* c) {
53   return c->interface_name();
54 }
55 }
56 
57 // static
58 const char* ConnectionHealthChecker::kDefaultRemoteIPPool[] = {
59     "74.125.224.47",
60     "74.125.224.79",
61     "74.125.224.111",
62     "74.125.224.143"
63 };
64 // static
65 const int ConnectionHealthChecker::kDNSTimeoutMilliseconds = 5000;
66 // static
67 const int ConnectionHealthChecker::kInvalidSocket = -1;
68 // static
69 const int ConnectionHealthChecker::kMaxFailedConnectionAttempts = 2;
70 // static
71 const int ConnectionHealthChecker::kMaxSentDataPollingAttempts = 2;
72 // static
73 const int ConnectionHealthChecker::kMinCongestedQueueAttempts = 2;
74 // static
75 const int ConnectionHealthChecker::kMinSuccessfulSendAttempts = 1;
76 // static
77 const int ConnectionHealthChecker::kNumDNSQueries = 5;
78 // static
79 const int ConnectionHealthChecker::kTCPStateUpdateWaitMilliseconds = 5000;
80 // static
81 const uint16_t ConnectionHealthChecker::kRemotePort = 80;
82 
ConnectionHealthChecker(ConnectionRefPtr connection,EventDispatcher * dispatcher,IPAddressStore * remote_ips,const base::Callback<void (Result)> & result_callback)83 ConnectionHealthChecker::ConnectionHealthChecker(
84     ConnectionRefPtr connection,
85     EventDispatcher* dispatcher,
86     IPAddressStore* remote_ips,
87     const base::Callback<void(Result)>& result_callback)
88     : connection_(connection),
89       dispatcher_(dispatcher),
90       remote_ips_(remote_ips),
91       result_callback_(result_callback),
92       socket_(new Sockets()),
93       weak_ptr_factory_(this),
94       connection_complete_callback_(
95           Bind(&ConnectionHealthChecker::OnConnectionComplete,
96                weak_ptr_factory_.GetWeakPtr())),
97       tcp_connection_(new AsyncConnection(connection_->interface_name(),
98                                           dispatcher_,
99                                           socket_.get(),
100                                           connection_complete_callback_)),
101       report_result_(
102           Bind(&ConnectionHealthChecker::ReportResult,
103                weak_ptr_factory_.GetWeakPtr())),
104       sock_fd_(kInvalidSocket),
105       socket_info_reader_(new SocketInfoReader()),
106       dns_client_factory_(DNSClientFactory::GetInstance()),
107       dns_client_callback_(Bind(&ConnectionHealthChecker::GetDNSResult,
108                                 weak_ptr_factory_.GetWeakPtr())),
109       health_check_in_progress_(false),
110       num_connection_failures_(0),
111       num_congested_queue_detected_(0),
112       num_successful_sends_(0),
113       tcp_state_update_wait_milliseconds_(kTCPStateUpdateWaitMilliseconds) {
114   for (size_t i = 0; i < arraysize(kDefaultRemoteIPPool); ++i) {
115     const char* ip_string = kDefaultRemoteIPPool[i];
116     IPAddress ip(IPAddress::kFamilyIPv4);
117     ip.SetAddressFromString(ip_string);
118     remote_ips_->AddUnique(ip);
119   }
120 }
121 
~ConnectionHealthChecker()122 ConnectionHealthChecker::~ConnectionHealthChecker() {
123   Stop();
124 }
125 
health_check_in_progress() const126 bool ConnectionHealthChecker::health_check_in_progress() const {
127   return health_check_in_progress_;
128 }
129 
AddRemoteIP(IPAddress ip)130 void ConnectionHealthChecker::AddRemoteIP(IPAddress ip) {
131   remote_ips_->AddUnique(ip);
132 }
133 
AddRemoteURL(const string & url_string)134 void ConnectionHealthChecker::AddRemoteURL(const string& url_string) {
135   GarbageCollectDNSClients();
136 
137   HTTPURL url;
138   if (!url.ParseFromString(url_string)) {
139     SLOG(connection_.get(), 2) << __func__ << ": Malformed url: "
140                                << url_string << ".";
141     return;
142   }
143   if (url.port() != kRemotePort) {
144     SLOG(connection_.get(), 2) << __func__
145                                << ": Remote connections only supported "
146                                << " to port 80, requested " << url.port()
147                                << ".";
148     return;
149   }
150   for (int i = 0; i < kNumDNSQueries; ++i) {
151     Error error;
152     DNSClient* dns_client =
153       dns_client_factory_->CreateDNSClient(IPAddress::kFamilyIPv4,
154                                            connection_->interface_name(),
155                                            connection_->dns_servers(),
156                                            kDNSTimeoutMilliseconds,
157                                            dispatcher_,
158                                            dns_client_callback_);
159     dns_clients_.push_back(dns_client);
160     if (!dns_clients_[i]->Start(url.host(), &error)) {
161       SLOG(connection_.get(), 2) << __func__ << ": Failed to start DNS client "
162                                  << "(query #" << i << "): "
163                                  << error.message();
164     }
165   }
166 }
167 
Start()168 void ConnectionHealthChecker::Start() {
169   if (health_check_in_progress_) {
170     SLOG(connection_.get(), 2) << __func__
171                                << ": Health Check already in progress.";
172     return;
173   }
174   if (!connection_.get()) {
175     SLOG(connection_.get(), 2) << __func__ << ": Connection not ready yet.";
176     result_callback_.Run(kResultUnknown);
177     return;
178   }
179 
180   health_check_in_progress_ = true;
181   num_connection_failures_ = 0;
182   num_congested_queue_detected_ = 0;
183   num_successful_sends_ = 0;
184 
185   if (remote_ips_->Empty()) {
186     // Nothing to try.
187     Stop();
188     SLOG(connection_.get(), 2) << __func__ << ": Not enough IPs.";
189     result_callback_.Run(kResultUnknown);
190     return;
191   }
192 
193   // Initiate the first attempt.
194   NextHealthCheckSample();
195 }
196 
Stop()197 void ConnectionHealthChecker::Stop() {
198   if (tcp_connection_.get() != nullptr)
199     tcp_connection_->Stop();
200   verify_sent_data_callback_.Cancel();
201   ClearSocketDescriptor();
202   health_check_in_progress_ = false;
203   num_connection_failures_ = 0;
204   num_congested_queue_detected_ = 0;
205   num_successful_sends_ = 0;
206   num_tx_queue_polling_attempts_ = 0;
207 }
208 
SetConnection(ConnectionRefPtr connection)209 void ConnectionHealthChecker::SetConnection(ConnectionRefPtr connection) {
210   SLOG(connection_.get(), 3) << __func__;
211   connection_ = connection;
212   tcp_connection_.reset(new AsyncConnection(connection_->interface_name(),
213                                             dispatcher_,
214                                             socket_.get(),
215                                             connection_complete_callback_));
216   dns_clients_.clear();
217   bool restart = health_check_in_progress();
218   Stop();
219   if (restart)
220     Start();
221 }
222 
ResultToString(ConnectionHealthChecker::Result result)223 const char* ConnectionHealthChecker::ResultToString(
224     ConnectionHealthChecker::Result result) {
225   switch (result) {
226     case kResultUnknown:
227       return "Unknown";
228     case kResultConnectionFailure:
229       return "ConnectionFailure";
230     case kResultCongestedTxQueue:
231       return "CongestedTxQueue";
232     case kResultSuccess:
233       return "Success";
234     default:
235       return "Invalid";
236   }
237 }
238 
GetDNSResult(const Error & error,const IPAddress & ip)239 void ConnectionHealthChecker::GetDNSResult(const Error& error,
240                                            const IPAddress& ip) {
241   if (!error.IsSuccess()) {
242     SLOG(connection_.get(), 2) << __func__ << "DNSClient returned failure: "
243                                << error.message();
244     return;
245   }
246   remote_ips_->AddUnique(ip);
247 }
248 
GarbageCollectDNSClients()249 void ConnectionHealthChecker::GarbageCollectDNSClients() {
250   ScopedVector<DNSClient> keep;
251   ScopedVector<DNSClient> discard;
252   for (size_t i = 0; i < dns_clients_.size(); ++i) {
253     if (dns_clients_[i]->IsActive())
254       keep.push_back(dns_clients_[i]);
255     else
256       discard.push_back(dns_clients_[i]);
257   }
258   dns_clients_.weak_clear();
259   dns_clients_ = std::move(keep);
260   discard.clear();
261 }
262 
NextHealthCheckSample()263 void ConnectionHealthChecker::NextHealthCheckSample() {
264   // Finish conditions:
265   if (num_connection_failures_ == kMaxFailedConnectionAttempts) {
266     health_check_result_ = kResultConnectionFailure;
267     dispatcher_->PostTask(report_result_);
268     return;
269   }
270   if (num_congested_queue_detected_ == kMinCongestedQueueAttempts) {
271     health_check_result_ = kResultCongestedTxQueue;
272     dispatcher_->PostTask(report_result_);
273     return;
274   }
275   if (num_successful_sends_ == kMinSuccessfulSendAttempts) {
276     health_check_result_ = kResultSuccess;
277     dispatcher_->PostTask(report_result_);
278     return;
279   }
280 
281   // Pick a random IP from the set of IPs.
282   // This guards against
283   //   (1) Repeated failed attempts for the same IP at start-up everytime.
284   //   (2) All users attempting to connect to the same IP.
285   IPAddress ip = remote_ips_->GetRandomIP();
286   SLOG(connection_.get(), 3) << __func__ << ": Starting connection at "
287                              << ip.ToString();
288   if (!tcp_connection_->Start(ip, kRemotePort)) {
289     SLOG(connection_.get(), 2) << __func__ << ": Connection attempt failed.";
290     ++num_connection_failures_;
291     NextHealthCheckSample();
292   }
293 }
294 
OnConnectionComplete(bool success,int sock_fd)295 void ConnectionHealthChecker::OnConnectionComplete(bool success, int sock_fd) {
296   if (!success) {
297     SLOG(connection_.get(), 2) << __func__
298                                << ": AsyncConnection connection attempt failed "
299                                << "with error: "
300                                << tcp_connection_->error();
301     ++num_connection_failures_;
302     NextHealthCheckSample();
303     return;
304   }
305 
306   SetSocketDescriptor(sock_fd);
307 
308   SocketInfo sock_info;
309   if (!GetSocketInfo(sock_fd_, &sock_info) ||
310       sock_info.connection_state() !=
311           SocketInfo::kConnectionStateEstablished) {
312     SLOG(connection_.get(), 2) << __func__
313                                << ": Connection originally not in established "
314                                   "state.";
315     // Count this as a failed connection attempt.
316     ++num_connection_failures_;
317     ClearSocketDescriptor();
318     NextHealthCheckSample();
319     return;
320   }
321 
322   old_transmit_queue_value_ = sock_info.transmit_queue_value();
323   num_tx_queue_polling_attempts_ = 0;
324 
325   // Send data on the connection and post a delayed task to check successful
326   // transfer.
327   char buf;
328   if (socket_->Send(sock_fd_, &buf, sizeof(buf), 0) == -1) {
329     SLOG(connection_.get(), 2) << __func__ << ": " << socket_->ErrorString();
330     // Count this as a failed connection attempt.
331     ++num_connection_failures_;
332     ClearSocketDescriptor();
333     NextHealthCheckSample();
334     return;
335   }
336 
337   verify_sent_data_callback_.Reset(
338       Bind(&ConnectionHealthChecker::VerifySentData, Unretained(this)));
339   dispatcher_->PostDelayedTask(verify_sent_data_callback_.callback(),
340                                tcp_state_update_wait_milliseconds_);
341 }
342 
VerifySentData()343 void ConnectionHealthChecker::VerifySentData() {
344   SocketInfo sock_info;
345   bool sock_info_found = GetSocketInfo(sock_fd_, &sock_info);
346   // Acceptable TCP connection states after sending the data:
347   // kConnectionStateEstablished: No change in connection state since the send.
348   // kConnectionStateCloseWait: The remote host recieved the sent data and
349   //    requested connection close.
350   if (!sock_info_found ||
351       (sock_info.connection_state() !=
352            SocketInfo::kConnectionStateEstablished &&
353       sock_info.connection_state() !=
354            SocketInfo::kConnectionStateCloseWait)) {
355     SLOG(connection_.get(), 2)
356         << __func__ << ": Connection not in acceptable state after send.";
357     if (sock_info_found)
358       SLOG(connection_.get(), 3) << "Found socket info but in state: "
359                                  << sock_info.connection_state();
360     ++num_connection_failures_;
361   } else if (sock_info.transmit_queue_value() > old_transmit_queue_value_ &&
362       sock_info.timer_state() ==
363           SocketInfo::kTimerStateRetransmitTimerPending) {
364     if (num_tx_queue_polling_attempts_ < kMaxSentDataPollingAttempts) {
365       SLOG(connection_.get(), 2) << __func__
366                                  << ": Polling again.";
367       ++num_tx_queue_polling_attempts_;
368       verify_sent_data_callback_.Reset(
369           Bind(&ConnectionHealthChecker::VerifySentData, Unretained(this)));
370       dispatcher_->PostDelayedTask(verify_sent_data_callback_.callback(),
371                                    tcp_state_update_wait_milliseconds_);
372       return;
373     }
374     SLOG(connection_.get(), 2) << __func__ << ": Sampled congested Tx-Queue";
375     ++num_congested_queue_detected_;
376   } else {
377     SLOG(connection_.get(), 2) << __func__ << ": Sampled successful send.";
378     ++num_successful_sends_;
379   }
380   ClearSocketDescriptor();
381   NextHealthCheckSample();
382 }
383 
384 // TODO(pprabhu): Scrub IP address logging.
GetSocketInfo(int sock_fd,SocketInfo * sock_info)385 bool ConnectionHealthChecker::GetSocketInfo(int sock_fd,
386                                             SocketInfo* sock_info) {
387   struct sockaddr_storage addr;
388   socklen_t addrlen = sizeof(addr);
389   memset(&addr, 0, sizeof(addr));
390   if (socket_->GetSockName(sock_fd,
391                            reinterpret_cast<struct sockaddr*>(&addr),
392                            &addrlen) != 0) {
393     SLOG(connection_.get(), 2) << __func__
394                                << ": Failed to get address of created socket.";
395     return false;
396   }
397   if (addr.ss_family != AF_INET) {
398     SLOG(connection_.get(), 2) << __func__ << ": IPv6 socket address found.";
399     return false;
400   }
401 
402   CHECK_EQ(sizeof(struct sockaddr_in), addrlen);
403   struct sockaddr_in* addr_in = reinterpret_cast<sockaddr_in*>(&addr);
404   uint16_t local_port = ntohs(addr_in->sin_port);
405   char ipstr[INET_ADDRSTRLEN];
406   const char* res = inet_ntop(AF_INET, &addr_in->sin_addr,
407                               ipstr, sizeof(ipstr));
408   if (res == nullptr) {
409     SLOG(connection_.get(), 2) << __func__
410                                << ": Could not convert IP address to string.";
411     return false;
412   }
413 
414   IPAddress local_ip_address(IPAddress::kFamilyIPv4);
415   CHECK(local_ip_address.SetAddressFromString(ipstr));
416   SLOG(connection_.get(), 3) << "Local IP = " << local_ip_address.ToString()
417                              << ":" << local_port;
418 
419   vector<SocketInfo> info_list;
420   if (!socket_info_reader_->LoadTcpSocketInfo(&info_list)) {
421     SLOG(connection_.get(), 2) << __func__
422                                << ": Failed to load TCP socket info.";
423     return false;
424   }
425 
426   for (vector<SocketInfo>::const_iterator info_list_it = info_list.begin();
427        info_list_it != info_list.end();
428        ++info_list_it) {
429     const SocketInfo& cur_sock_info = *info_list_it;
430 
431     SLOG(connection_.get(), 4)
432         << "Testing against IP = "
433         << cur_sock_info.local_ip_address().ToString()
434         << ":" << cur_sock_info.local_port()
435         << " (addresses equal:"
436         << cur_sock_info.local_ip_address().Equals(local_ip_address)
437         << ", ports equal:" << (cur_sock_info.local_port() == local_port)
438         << ")";
439 
440     if (cur_sock_info.local_ip_address().Equals(local_ip_address) &&
441         cur_sock_info.local_port() == local_port) {
442       SLOG(connection_.get(), 3) << __func__
443                                  << ": Found matching TCP socket info.";
444       *sock_info = cur_sock_info;
445       return true;
446     }
447   }
448 
449   SLOG(connection_.get(), 2) << __func__ << ": No matching TCP socket info.";
450   return false;
451 }
452 
ReportResult()453 void ConnectionHealthChecker::ReportResult() {
454   SLOG(connection_.get(), 2) << __func__ << ": Result: "
455                              << ResultToString(health_check_result_);
456   Stop();
457   result_callback_.Run(health_check_result_);
458 }
459 
SetSocketDescriptor(int sock_fd)460 void ConnectionHealthChecker::SetSocketDescriptor(int sock_fd) {
461   if (sock_fd_ != kInvalidSocket) {
462     SLOG(connection_.get(), 4) << "Closing socket";
463     socket_->Close(sock_fd_);
464   }
465   sock_fd_ = sock_fd;
466 }
467 
ClearSocketDescriptor()468 void ConnectionHealthChecker::ClearSocketDescriptor() {
469   SetSocketDescriptor(kInvalidSocket);
470 }
471 
472 }  // namespace shill
473