1 //
2 // Copyright (C) 2012 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16
17 #include "shill/http_request.h"
18
19 #include <string>
20
21 #include <base/bind.h>
22 #include <base/strings/string_number_conversions.h>
23 #include <base/strings/stringprintf.h>
24
25 #include "shill/async_connection.h"
26 #include "shill/connection.h"
27 #include "shill/dns_client.h"
28 #include "shill/error.h"
29 #include "shill/event_dispatcher.h"
30 #include "shill/http_url.h"
31 #include "shill/logging.h"
32 #include "shill/net/ip_address.h"
33 #include "shill/net/sockets.h"
34
35 using base::Bind;
36 using base::Callback;
37 using base::StringPrintf;
38 using std::string;
39
40 namespace shill {
41
42 namespace Logging {
43 static auto kModuleLogScope = ScopeLogger::kHTTP;
ObjectID(Connection * c)44 static string ObjectID(Connection* c) { return c->interface_name(); }
45 }
46
47 const int HTTPRequest::kConnectTimeoutSeconds = 10;
48 const int HTTPRequest::kDNSTimeoutSeconds = 5;
49 const int HTTPRequest::kInputTimeoutSeconds = 10;
50
51 const char HTTPRequest::kHTTPRequestTemplate[] =
52 "GET %s HTTP/1.1\r\n"
53 "Host: %s:%d\r\n"
54 "Connection: Close\r\n\r\n";
55
HTTPRequest(ConnectionRefPtr connection,EventDispatcher * dispatcher,Sockets * sockets)56 HTTPRequest::HTTPRequest(ConnectionRefPtr connection,
57 EventDispatcher* dispatcher,
58 Sockets* sockets)
59 : connection_(connection),
60 dispatcher_(dispatcher),
61 sockets_(sockets),
62 weak_ptr_factory_(this),
63 connect_completion_callback_(
64 Bind(&HTTPRequest::OnConnectCompletion,
65 weak_ptr_factory_.GetWeakPtr())),
66 dns_client_callback_(Bind(&HTTPRequest::GetDNSResult,
67 weak_ptr_factory_.GetWeakPtr())),
68 read_server_callback_(Bind(&HTTPRequest::ReadFromServer,
69 weak_ptr_factory_.GetWeakPtr())),
70 write_server_callback_(Bind(&HTTPRequest::WriteToServer,
71 weak_ptr_factory_.GetWeakPtr())),
72 dns_client_(
73 new DNSClient(connection->IsIPv6() ? IPAddress::kFamilyIPv6
74 : IPAddress::kFamilyIPv4,
75 connection->interface_name(),
76 connection->dns_servers(),
77 kDNSTimeoutSeconds * 1000,
78 dispatcher,
79 dns_client_callback_)),
80 server_async_connection_(
81 new AsyncConnection(connection_->interface_name(),
82 dispatcher_, sockets,
83 connect_completion_callback_)),
84 server_port_(-1),
85 server_socket_(-1),
86 timeout_result_(kResultUnknown),
87 is_running_(false) { }
88
~HTTPRequest()89 HTTPRequest::~HTTPRequest() {
90 Stop();
91 }
92
Start(const HTTPURL & url,const Callback<void (const ByteString &)> & read_event_callback,const Callback<void (Result,const ByteString &)> & result_callback)93 HTTPRequest::Result HTTPRequest::Start(
94 const HTTPURL& url,
95 const Callback<void(const ByteString&)>& read_event_callback,
96 const Callback<void(Result, const ByteString&)>& result_callback) {
97 SLOG(connection_.get(), 3) << "In " << __func__;
98
99 DCHECK(!is_running_);
100
101 is_running_ = true;
102 request_data_ = ByteString(StringPrintf(kHTTPRequestTemplate,
103 url.path().c_str(),
104 url.host().c_str(),
105 url.port()), false);
106 server_hostname_ = url.host();
107 server_port_ = url.port();
108 connection_->RequestRouting();
109
110 IPAddress addr(IPAddress::kFamilyIPv4);
111 if (connection_->IsIPv6()) {
112 addr.set_family(IPAddress::kFamilyIPv6);
113 }
114 if (addr.SetAddressFromString(server_hostname_)) {
115 if (!ConnectServer(addr, server_port_)) {
116 LOG(ERROR) << "Connect to "
117 << server_hostname_
118 << " failed synchronously";
119 return kResultConnectionFailure;
120 }
121 } else {
122 SLOG(connection_.get(), 3) << "Looking up host: " << server_hostname_;
123 Error error;
124 if (!dns_client_->Start(server_hostname_, &error)) {
125 LOG(ERROR) << "Failed to start DNS client: " << error.message();
126 Stop();
127 return kResultDNSFailure;
128 }
129 }
130
131 // Only install callbacks after connection succeeds in starting.
132 read_event_callback_ = read_event_callback;
133 result_callback_ = result_callback;
134
135 return kResultInProgress;
136 }
137
Stop()138 void HTTPRequest::Stop() {
139 SLOG(connection_.get(), 3) << "In " << __func__ << "; running is "
140 << is_running_;
141
142 if (!is_running_) {
143 return;
144 }
145
146 // Clear IO handlers first so that closing the socket doesn't cause
147 // events to fire.
148 write_server_handler_.reset();
149 read_server_handler_.reset();
150
151 connection_->ReleaseRouting();
152 dns_client_->Stop();
153 is_running_ = false;
154 result_callback_.Reset();
155 read_event_callback_.Reset();
156 request_data_.Clear();
157 response_data_.Clear();
158 server_async_connection_->Stop();
159 server_hostname_.clear();
160 server_port_ = -1;
161 if (server_socket_ != -1) {
162 sockets_->Close(server_socket_);
163 server_socket_ = -1;
164 }
165 timeout_closure_.Cancel();
166 timeout_result_ = kResultUnknown;
167 }
168
ConnectServer(const IPAddress & address,int port)169 bool HTTPRequest::ConnectServer(const IPAddress& address, int port) {
170 SLOG(connection_.get(), 3) << "In " << __func__;
171 if (!server_async_connection_->Start(address, port)) {
172 LOG(ERROR) << "Could not create socket to connect to server at "
173 << address.ToString();
174 SendStatus(kResultConnectionFailure);
175 return false;
176 }
177 // Start a connection timeout only if we didn't synchronously connect.
178 if (server_socket_ == -1) {
179 StartIdleTimeout(kConnectTimeoutSeconds, kResultConnectionTimeout);
180 }
181 return true;
182 }
183
184 // DNSClient callback that fires when the DNS request completes.
GetDNSResult(const Error & error,const IPAddress & address)185 void HTTPRequest::GetDNSResult(const Error& error, const IPAddress& address) {
186 SLOG(connection_.get(), 3) << "In " << __func__;
187 if (!error.IsSuccess()) {
188 LOG(ERROR) << "Could not resolve hostname "
189 << server_hostname_
190 << ": "
191 << error.message();
192 if (error.message() == DNSClient::kErrorTimedOut) {
193 SendStatus(kResultDNSTimeout);
194 } else {
195 SendStatus(kResultDNSFailure);
196 }
197 return;
198 }
199 ConnectServer(address, server_port_);
200 }
201
202 // AsyncConnection callback routine which fires when the asynchronous Connect()
203 // to the remote server completes (or fails).
OnConnectCompletion(bool success,int fd)204 void HTTPRequest::OnConnectCompletion(bool success, int fd) {
205 SLOG(connection_.get(), 3) << "In " << __func__;
206 if (!success) {
207 LOG(ERROR) << "Socket connection delayed failure to "
208 << server_hostname_
209 << ": "
210 << server_async_connection_->error();
211 // |this| could be freed as a result of calling SendStatus().
212 SendStatus(kResultConnectionFailure);
213 return;
214 }
215 server_socket_ = fd;
216 write_server_handler_.reset(
217 dispatcher_->CreateReadyHandler(server_socket_,
218 IOHandler::kModeOutput,
219 write_server_callback_));
220 StartIdleTimeout(kInputTimeoutSeconds, kResultRequestTimeout);
221 }
222
OnServerReadError(const string &)223 void HTTPRequest::OnServerReadError(const string& /*error_msg*/) {
224 SendStatus(kResultResponseFailure);
225 }
226
227 // IOInputHandler callback which fires when data has been read from the
228 // server.
ReadFromServer(InputData * data)229 void HTTPRequest::ReadFromServer(InputData* data) {
230 SLOG(connection_.get(), 3) << "In " << __func__ << " length " << data->len;
231 if (data->len == 0) {
232 SendStatus(kResultSuccess);
233 return;
234 }
235
236 response_data_.Append(ByteString(data->buf, data->len));
237 StartIdleTimeout(kInputTimeoutSeconds, kResultResponseTimeout);
238 if (!read_event_callback_.is_null()) {
239 read_event_callback_.Run(response_data_);
240 }
241 }
242
SendStatus(Result result)243 void HTTPRequest::SendStatus(Result result) {
244 // Save copies on the stack, since Stop() will remove them.
245 Callback<void(Result, const ByteString&)> result_callback = result_callback_;
246 const ByteString response_data(response_data_);
247 Stop();
248
249 // Call the callback last, since it may delete us and |this| may no longer
250 // be valid.
251 if (!result_callback.is_null()) {
252 result_callback.Run(result, response_data);
253 }
254 }
255
256 // Start a timeout for "the next event".
StartIdleTimeout(int timeout_seconds,Result timeout_result)257 void HTTPRequest::StartIdleTimeout(int timeout_seconds, Result timeout_result) {
258 timeout_result_ = timeout_result;
259 timeout_closure_.Reset(
260 Bind(&HTTPRequest::TimeoutTask, weak_ptr_factory_.GetWeakPtr()));
261 dispatcher_->PostDelayedTask(timeout_closure_.callback(),
262 timeout_seconds * 1000);
263 }
264
TimeoutTask()265 void HTTPRequest::TimeoutTask() {
266 LOG(ERROR) << "Connection with "
267 << server_hostname_
268 << " timed out";
269 SendStatus(timeout_result_);
270 }
271
272 // Output ReadyHandler callback which fires when the server socket is
273 // ready for data to be sent to it.
WriteToServer(int fd)274 void HTTPRequest::WriteToServer(int fd) {
275 CHECK_EQ(server_socket_, fd);
276 int ret = sockets_->Send(fd, request_data_.GetConstData(),
277 request_data_.GetLength(), 0);
278 CHECK(ret < 0 || static_cast<size_t>(ret) <= request_data_.GetLength());
279
280 SLOG(connection_.get(), 3) << "In " << __func__ << " wrote " << ret << " of "
281 << request_data_.GetLength();
282
283 if (ret < 0) {
284 LOG(ERROR) << "Client write failed to "
285 << server_hostname_;
286 SendStatus(kResultRequestFailure);
287 return;
288 }
289
290 request_data_ = ByteString(request_data_.GetConstData() + ret,
291 request_data_.GetLength() - ret);
292
293 if (request_data_.IsEmpty()) {
294 write_server_handler_->Stop();
295 read_server_handler_.reset(dispatcher_->CreateInputHandler(
296 server_socket_,
297 read_server_callback_,
298 Bind(&HTTPRequest::OnServerReadError, weak_ptr_factory_.GetWeakPtr())));
299 StartIdleTimeout(kInputTimeoutSeconds, kResultResponseTimeout);
300 } else {
301 StartIdleTimeout(kInputTimeoutSeconds, kResultRequestTimeout);
302 }
303 }
304
305 } // namespace shill
306