1 //
2 // Copyright (C) 2012 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16
17 #include "shill/connectivity_trial.h"
18
19 #include <string>
20
21 #include <base/bind.h>
22 #include <base/strings/pattern.h>
23 #include <base/strings/string_number_conversions.h>
24 #include <base/strings/string_util.h>
25 #include <base/strings/stringprintf.h>
26 #if defined(__ANDROID__)
27 #include <dbus/service_constants.h>
28 #else
29 #include <chromeos/dbus/service_constants.h>
30 #endif // __ANDROID__
31
32 #include "shill/async_connection.h"
33 #include "shill/connection.h"
34 #include "shill/dns_client.h"
35 #include "shill/event_dispatcher.h"
36 #include "shill/http_request.h"
37 #include "shill/http_url.h"
38 #include "shill/logging.h"
39 #include "shill/net/ip_address.h"
40 #include "shill/net/sockets.h"
41
42 using base::Bind;
43 using base::Callback;
44 using base::StringPrintf;
45 using std::string;
46
47 namespace shill {
48
49 namespace Logging {
50 static auto kModuleLogScope = ScopeLogger::kPortal;
ObjectID(Connection * c)51 static string ObjectID(Connection* c) { return c->interface_name(); }
52 }
53
// Default URL for a connectivity trial. NOTE(review): not referenced in this
// file; presumably passed to Start() by callers -- confirm.
const char ConnectivityTrial::kDefaultURL[] =
    "http://www.gstatic.com/generate_204";
// Pattern that the start of the HTTP response must match for the trial to
// succeed. RequestReadCallback() compares it with base::MatchPattern, so the
// '?' characters act as single-character wildcards (e.g. any HTTP version).
const char ConnectivityTrial::kResponseExpected[] = "HTTP/?.? 204";
57
// Creates a trial whose HTTP requests run over |connection|, with delayed
// tasks posted to |dispatcher|. Once a request is in progress it has
// |trial_timeout_seconds| to finish before the trial times out. Every outcome
// is reported through |callback|. The trial starts inactive.
ConnectivityTrial::ConnectivityTrial(
    ConnectionRefPtr connection,
    EventDispatcher* dispatcher,
    int trial_timeout_seconds,
    const Callback<void(Result)>& callback)
    : connection_(connection),
      dispatcher_(dispatcher),
      trial_timeout_seconds_(trial_timeout_seconds),
      trial_callback_(callback),
      weak_ptr_factory_(this),
      // NOTE(review): the two request callbacks below call
      // weak_ptr_factory_.GetWeakPtr() during construction. C++ initializes
      // members in header declaration order, not in the order written here,
      // so this relies on weak_ptr_factory_ being declared before these
      // callbacks in connectivity_trial.h -- confirm.
      request_read_callback_(
          Bind(&ConnectivityTrial::RequestReadCallback,
               weak_ptr_factory_.GetWeakPtr())),
      request_result_callback_(
          Bind(&ConnectivityTrial::RequestResultCallback,
               weak_ptr_factory_.GetWeakPtr())),
      is_active_(false) { }
75
~ConnectivityTrial()76 ConnectivityTrial::~ConnectivityTrial() {
77 Stop();
78 }
79
Retry(int start_delay_milliseconds)80 bool ConnectivityTrial::Retry(int start_delay_milliseconds) {
81 SLOG(connection_.get(), 3) << "In " << __func__;
82 if (request_.get())
83 CleanupTrial(false);
84 else
85 return false;
86 StartTrialAfterDelay(start_delay_milliseconds);
87 return true;
88 }
89
Start(const string & url_string,int start_delay_milliseconds)90 bool ConnectivityTrial::Start(const string& url_string,
91 int start_delay_milliseconds) {
92 SLOG(connection_.get(), 3) << "In " << __func__;
93
94 if (!url_.ParseFromString(url_string)) {
95 LOG(ERROR) << "Failed to parse URL string: " << url_string;
96 return false;
97 }
98 if (request_.get()) {
99 CleanupTrial(false);
100 } else {
101 request_.reset(new HTTPRequest(connection_, dispatcher_, &sockets_));
102 }
103 StartTrialAfterDelay(start_delay_milliseconds);
104 return true;
105 }
106
Stop()107 void ConnectivityTrial::Stop() {
108 SLOG(connection_.get(), 3) << "In " << __func__;
109
110 if (!request_.get()) {
111 return;
112 }
113
114 CleanupTrial(true);
115 }
116
StartTrialAfterDelay(int start_delay_milliseconds)117 void ConnectivityTrial::StartTrialAfterDelay(int start_delay_milliseconds) {
118 SLOG(connection_.get(), 4) << "In " << __func__
119 << " delay = " << start_delay_milliseconds
120 << "ms.";
121 trial_.Reset(Bind(&ConnectivityTrial::StartTrialTask,
122 weak_ptr_factory_.GetWeakPtr()));
123 dispatcher_->PostDelayedTask(trial_.callback(), start_delay_milliseconds);
124 }
125
StartTrialTask()126 void ConnectivityTrial::StartTrialTask() {
127 HTTPRequest::Result result =
128 request_->Start(url_, request_read_callback_, request_result_callback_);
129 if (result != HTTPRequest::kResultInProgress) {
130 CompleteTrial(ConnectivityTrial::GetPortalResultForRequestResult(result));
131 return;
132 }
133 is_active_ = true;
134
135 trial_timeout_.Reset(Bind(&ConnectivityTrial::TimeoutTrialTask,
136 weak_ptr_factory_.GetWeakPtr()));
137 dispatcher_->PostDelayedTask(trial_timeout_.callback(),
138 trial_timeout_seconds_ * 1000);
139 }
140
IsActive()141 bool ConnectivityTrial::IsActive() {
142 return is_active_;
143 }
144
RequestReadCallback(const ByteString & response_data)145 void ConnectivityTrial::RequestReadCallback(const ByteString& response_data) {
146 const string response_expected(kResponseExpected);
147 bool expected_length_received = false;
148 int compare_length = 0;
149 if (response_data.GetLength() < response_expected.length()) {
150 // There isn't enough data yet for a final decision, but we can still
151 // test to see if the partial string matches so far.
152 expected_length_received = false;
153 compare_length = response_data.GetLength();
154 } else {
155 expected_length_received = true;
156 compare_length = response_expected.length();
157 }
158
159 if (base::MatchPattern(
160 string(reinterpret_cast<const char*>(response_data.GetConstData()),
161 compare_length),
162 response_expected.substr(0, compare_length))) {
163 if (expected_length_received) {
164 CompleteTrial(Result(kPhaseContent, kStatusSuccess));
165 }
166 // Otherwise, we wait for more data from the server.
167 } else {
168 CompleteTrial(Result(kPhaseContent, kStatusFailure));
169 }
170 }
171
RequestResultCallback(HTTPRequest::Result result,const ByteString &)172 void ConnectivityTrial::RequestResultCallback(
173 HTTPRequest::Result result, const ByteString& /*response_data*/) {
174 CompleteTrial(GetPortalResultForRequestResult(result));
175 }
176
CompleteTrial(Result result)177 void ConnectivityTrial::CompleteTrial(Result result) {
178 SLOG(connection_.get(), 3)
179 << StringPrintf("Connectivity Trial completed with phase==%s, status==%s",
180 PhaseToString(result.phase).c_str(),
181 StatusToString(result.status).c_str());
182 CleanupTrial(false);
183 trial_callback_.Run(result);
184 }
185
CleanupTrial(bool reset_request)186 void ConnectivityTrial::CleanupTrial(bool reset_request) {
187 trial_timeout_.Cancel();
188
189 if (request_.get())
190 request_->Stop();
191
192 is_active_ = false;
193
194 if (!reset_request || !request_.get())
195 return;
196
197 request_.reset();
198 }
199
TimeoutTrialTask()200 void ConnectivityTrial::TimeoutTrialTask() {
201 LOG(ERROR) << "Connectivity Trial - Request timed out";
202 if (request_->response_data().GetLength()) {
203 CompleteTrial(ConnectivityTrial::Result(ConnectivityTrial::kPhaseContent,
204 ConnectivityTrial::kStatusTimeout));
205 } else {
206 CompleteTrial(ConnectivityTrial::Result(ConnectivityTrial::kPhaseUnknown,
207 ConnectivityTrial::kStatusTimeout));
208 }
209 }
210
211 // statiic
PhaseToString(Phase phase)212 const string ConnectivityTrial::PhaseToString(Phase phase) {
213 switch (phase) {
214 case kPhaseConnection:
215 return kPortalDetectionPhaseConnection;
216 case kPhaseDNS:
217 return kPortalDetectionPhaseDns;
218 case kPhaseHTTP:
219 return kPortalDetectionPhaseHttp;
220 case kPhaseContent:
221 return kPortalDetectionPhaseContent;
222 case kPhaseUnknown:
223 default:
224 return kPortalDetectionPhaseUnknown;
225 }
226 }
227
228 // static
StatusToString(Status status)229 const string ConnectivityTrial::StatusToString(Status status) {
230 switch (status) {
231 case kStatusSuccess:
232 return kPortalDetectionStatusSuccess;
233 case kStatusTimeout:
234 return kPortalDetectionStatusTimeout;
235 case kStatusFailure:
236 default:
237 return kPortalDetectionStatusFailure;
238 }
239 }
240
GetPortalResultForRequestResult(HTTPRequest::Result result)241 ConnectivityTrial::Result ConnectivityTrial::GetPortalResultForRequestResult(
242 HTTPRequest::Result result) {
243 switch (result) {
244 case HTTPRequest::kResultSuccess:
245 // The request completed without receiving the expected payload.
246 return Result(kPhaseContent, kStatusFailure);
247 case HTTPRequest::kResultDNSFailure:
248 return Result(kPhaseDNS, kStatusFailure);
249 case HTTPRequest::kResultDNSTimeout:
250 return Result(kPhaseDNS, kStatusTimeout);
251 case HTTPRequest::kResultConnectionFailure:
252 return Result(kPhaseConnection, kStatusFailure);
253 case HTTPRequest::kResultConnectionTimeout:
254 return Result(kPhaseConnection, kStatusTimeout);
255 case HTTPRequest::kResultRequestFailure:
256 case HTTPRequest::kResultResponseFailure:
257 return Result(kPhaseHTTP, kStatusFailure);
258 case HTTPRequest::kResultRequestTimeout:
259 case HTTPRequest::kResultResponseTimeout:
260 return Result(kPhaseHTTP, kStatusTimeout);
261 case HTTPRequest::kResultUnknown:
262 default:
263 return Result(kPhaseUnknown, kStatusFailure);
264 }
265 }
266
267 } // namespace shill
268