1 // Copyright 2015 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "webservd/protocol_handler.h"
16 
17 #include <linux/tcp.h>
18 #include <microhttpd.h>
19 #include <netinet/in.h>
20 #include <sys/socket.h>
21 
22 #include <algorithm>
23 #include <limits>
24 #include <vector>
25 
26 #include <base/bind.h>
27 #include <base/guid.h>
28 #include <base/logging.h>
29 #include <base/message_loop/message_loop.h>
30 
31 #include "webservd/request.h"
32 #include "webservd/request_handler_interface.h"
33 #include "webservd/server_interface.h"
34 
35 namespace webservd {
36 
37 // Helper class to provide static callback methods to libmicrohttpd library,
38 // with the ability to access private methods of Server class.
39 class ServerHelper final {
40  public:
ConnectionHandler(void * cls,MHD_Connection * connection,const char * url,const char * method,const char * version,const char * upload_data,size_t * upload_data_size,void ** con_cls)41   static int ConnectionHandler(void *cls,
42                                MHD_Connection* connection,
43                                const char* url,
44                                const char* method,
45                                const char* version,
46                                const char* upload_data,
47                                size_t* upload_data_size,
48                                void** con_cls) {
49     auto handler = reinterpret_cast<ProtocolHandler*>(cls);
50     if (nullptr == *con_cls) {
51       std::string request_handler_id = handler->FindRequestHandler(url, method);
52       std::unique_ptr<Request> request{new Request{
53           request_handler_id, url, method, version, connection, handler
54       }};
55       if (!request->BeginRequestData())
56         return MHD_NO;
57 
58       // Pass the raw pointer here in order to interface with libmicrohttpd's
59       // old-style C API.
60       *con_cls = request.release();
61     } else {
62       auto request = reinterpret_cast<Request*>(*con_cls);
63       if (*upload_data_size) {
64         if (!request->AddRequestData(upload_data, upload_data_size))
65           return MHD_NO;
66       } else {
67         request->EndRequestData();
68       }
69     }
70     return MHD_YES;
71   }
72 
RequestCompleted(void *,MHD_Connection *,void ** con_cls,MHD_RequestTerminationCode toe)73   static void RequestCompleted(void* /* cls */,
74                                MHD_Connection*  /* connection */,
75                                void** con_cls,
76                                MHD_RequestTerminationCode toe) {
77     if (toe != MHD_REQUEST_TERMINATED_COMPLETED_OK) {
78       LOG(ERROR) << "Web request terminated abnormally with error code: "
79                  << toe;
80     }
81     auto request = reinterpret_cast<Request*>(*con_cls);
82     *con_cls = nullptr;
83     delete request;
84   }
85 };
86 
ProtocolHandler(const std::string & name,ServerInterface * server_interface)87 ProtocolHandler::ProtocolHandler(const std::string& name,
88                                  ServerInterface* server_interface)
89     : id_{base::GenerateGUID()},
90       name_{name},
91       server_interface_{server_interface} {}
92 
~ProtocolHandler()93 ProtocolHandler::~ProtocolHandler() {
94   Stop();
95 }
96 
AddRequestHandler(const std::string & url,const std::string & method,std::unique_ptr<RequestHandlerInterface> handler)97 std::string ProtocolHandler::AddRequestHandler(
98     const std::string& url,
99     const std::string& method,
100     std::unique_ptr<RequestHandlerInterface> handler) {
101   std::string handler_id = base::GenerateGUID();
102   request_handlers_.emplace(handler_id,
103                             HandlerMapEntry{url, method, std::move(handler)});
104   return handler_id;
105 }
106 
RemoveRequestHandler(const std::string & handler_id)107 bool ProtocolHandler::RemoveRequestHandler(const std::string& handler_id) {
108   return request_handlers_.erase(handler_id) == 1;
109 }
110 
FindRequestHandler(const base::StringPiece & url,const base::StringPiece & method) const111 std::string ProtocolHandler::FindRequestHandler(
112     const base::StringPiece& url,
113     const base::StringPiece& method) const {
114   size_t score = std::numeric_limits<size_t>::max();
115   std::string handler_id;
116   for (const auto& pair : request_handlers_) {
117     std::string handler_url = pair.second.url;
118     bool url_match = (handler_url == url);
119     bool method_match = (pair.second.method == method);
120 
121     // Try exact match first. If everything matches, we have our handler.
122     if (url_match && method_match)
123       return pair.first;
124 
125     // Calculate the current handler's similarity score. The lower the score
126     // the better the match is...
127     size_t current_score = 0;
128     if (!url_match && !handler_url.empty() && handler_url.back() == '/') {
129       if (url.starts_with(handler_url)) {
130         url_match = true;
131         // Use the difference in URL length as URL match quality proxy.
132         // The longer URL, the more specific (better) match is.
133         // Multiply by 2 to allow for extra score point for matching the method.
134         current_score = (url.size() - handler_url.size()) * 2;
135       }
136     }
137 
138     if (!method_match && pair.second.method.empty()) {
139       // If the handler didn't specify the method it handles, this means
140       // it doesn't care. However this isn't the exact match, so bump
141       // the score up one point.
142       method_match = true;
143       ++current_score;
144     }
145 
146     if (url_match && method_match && current_score < score) {
147       score = current_score;
148       handler_id = pair.first;
149     }
150   }
151 
152   return handler_id;
153 }
154 
Start(Config::ProtocolHandler * config)155 bool ProtocolHandler::Start(Config::ProtocolHandler* config) {
156   if (server_) {
157     LOG(ERROR) << "Protocol handler is already running.";
158     return false;
159   }
160 
161   // If using TLS, the certificate, private key and fingerprint must be
162   // provided.
163   CHECK_EQ(config->use_tls, !config->private_key.empty());
164   CHECK_EQ(config->use_tls, !config->certificate.empty());
165   CHECK_EQ(config->use_tls, !config->certificate_fingerprint.empty());
166 
167   LOG(INFO) << "Starting " << (config->use_tls ? "HTTPS" : "HTTP")
168             << " protocol handler on port: " << config->port;
169 
170   port_ = config->port;
171   protocol_ = (config->use_tls ? "https" : "http");
172   certificate_fingerprint_ = config->certificate_fingerprint;
173 
174   auto callback_addr =
175       reinterpret_cast<intptr_t>(&ServerHelper::RequestCompleted);
176   uint32_t flags = MHD_NO_FLAG;
177   if (server_interface_->GetConfig().use_debug)
178     flags |= MHD_USE_DEBUG;
179 
180   // Enable IPv6 if supported.
181   if (server_interface_->GetConfig().use_ipv6)
182     flags |= MHD_USE_DUAL_STACK;
183   flags |= MHD_USE_TCP_FASTOPEN;  // Use TCP Fast Open (see RFC 7413).
184   flags |= MHD_USE_SUSPEND_RESUME;  // Allow suspending/resuming connections.
185 
186   // MHD uses timeout of 0 to mean there is no timeout.
187   int timeout = server_interface_->GetConfig().default_request_timeout_seconds;
188   if (timeout < 0)
189     timeout = 0;
190 
191   std::vector<MHD_OptionItem> options{
192     {MHD_OPTION_CONNECTION_LIMIT, 10, nullptr},
193     {MHD_OPTION_CONNECTION_TIMEOUT, timeout, nullptr},
194     {MHD_OPTION_NOTIFY_COMPLETED, callback_addr, nullptr},
195   };
196 
197   if (config->socket_fd != -1) {
198     // Take ownership of the socket.
199     int socket_fd = config->socket_fd;
200     config->socket_fd = -1;
201 
202     // Set some more socket options. These options were set in libmicrohttpd.
203     int on = 1;
204     if (setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) {
205       // Treat this as a non-fatal failure. Just continue after logging.
206       PLOG(WARNING) << "Failed to set SO_REUSEADDR option on listening socket.";
207     }
208     on = (MHD_USE_DUAL_STACK != (flags & MHD_USE_DUAL_STACK));
209     if (setsockopt(socket_fd, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)) < 0) {
210       PLOG(WARNING) << "Failed to set IPV6_V6ONLY option on listening socket.";
211       close(socket_fd);
212       return false;
213     }
214 
215     // Bind socket to the port.
216     sockaddr_in6 addr = {};
217     addr.sin6_family = AF_INET6;
218     addr.sin6_port = htons(config->port);
219     if (bind(socket_fd, reinterpret_cast<const sockaddr*>(&addr),
220              sizeof(addr)) < 0) {
221       PLOG(ERROR) << "Failed to bind the socket to port " << config->port;
222       close(socket_fd);
223       return false;
224     }
225     if ((flags & MHD_USE_TCP_FASTOPEN) != 0) {
226       // This is the default value from libmicrohttpd.
227       int fastopen_queue_size = 10;
228       if (setsockopt(socket_fd, IPPROTO_TCP, TCP_FASTOPEN,
229                      &fastopen_queue_size, sizeof(fastopen_queue_size)) < 0) {
230         // Treat this as a non-fatal failure. Just continue after logging.
231         PLOG(WARNING) << "Failed to set TCP_FASTOPEN option on socket.";
232       }
233     }
234 
235     // Start listening on the socket.
236     // 32 connections is the value used by libmicrohttpd.
237     if (listen(socket_fd, 32) < 0) {
238       PLOG(ERROR) << "Failed to listen for connections on the socket.";
239       close(socket_fd);
240       return false;
241     }
242 
243     // Finally, pass the socket to libmicrohttpd.
244     options.push_back(
245         MHD_OptionItem{MHD_OPTION_LISTEN_SOCKET, socket_fd, nullptr});
246   }
247 
248   // libmicrohttpd expects both the key and certificate to be zero-terminated
249   // strings. Make sure they are terminated properly.
250   brillo::SecureBlob private_key_copy = config->private_key;
251   brillo::Blob certificate_copy = config->certificate;
252   private_key_copy.push_back(0);
253   certificate_copy.push_back(0);
254 
255   if (config->use_tls) {
256     flags |= MHD_USE_SSL;
257     options.push_back(
258         MHD_OptionItem{MHD_OPTION_HTTPS_MEM_KEY, 0, private_key_copy.data()});
259     options.push_back(
260         MHD_OptionItem{MHD_OPTION_HTTPS_MEM_CERT, 0, certificate_copy.data()});
261   }
262 
263   options.push_back(MHD_OptionItem{MHD_OPTION_END, 0, nullptr});
264 
265   server_ = MHD_start_daemon(flags, config->port, nullptr, nullptr,
266                              &ServerHelper::ConnectionHandler, this,
267                              MHD_OPTION_ARRAY, options.data(), MHD_OPTION_END);
268   if (!server_) {
269     PLOG(ERROR) << "Failed to create protocol handler on port " << config->port;
270     return false;
271   }
272   server_interface_->ProtocolHandlerStarted(this);
273   DoWork();
274   LOG(INFO) << "Protocol handler started";
275   return true;
276 }
277 
Stop()278 bool ProtocolHandler::Stop() {
279   if (server_) {
280     LOG(INFO) << "Shutting down the protocol handler...";
281     MHD_stop_daemon(server_);
282     server_ = nullptr;
283     server_interface_->ProtocolHandlerStopped(this);
284     LOG(INFO) << "Protocol handler shutdown complete";
285   }
286   port_ = 0;
287   protocol_.clear();
288   certificate_fingerprint_.clear();
289   return true;
290 }
291 
AddRequest(Request * request)292 void ProtocolHandler::AddRequest(Request* request) {
293   requests_.emplace(request->GetID(), request);
294 }
295 
RemoveRequest(Request * request)296 void ProtocolHandler::RemoveRequest(Request* request) {
297   requests_.erase(request->GetID());
298 }
299 
GetRequest(const std::string & request_id) const300 Request* ProtocolHandler::GetRequest(const std::string& request_id) const {
301   auto p = requests_.find(request_id);
302   return (p != requests_.end()) ? p->second : nullptr;
303 }
304 
305 // A file descriptor watcher class that oversees I/O operation notification
306 // on particular socket file descriptor.
307 class ProtocolHandler::Watcher final : public base::MessageLoopForIO::Watcher {
308  public:
Watcher(ProtocolHandler * handler,int fd)309   Watcher(ProtocolHandler* handler, int fd) : fd_{fd}, handler_{handler} {}
310 
Watch(bool read,bool write)311   void Watch(bool read, bool write) {
312     if (read == watching_read_ && write == watching_write_ && !triggered_)
313       return;
314 
315     controller_.StopWatchingFileDescriptor();
316     watching_read_ = read;
317     watching_write_ = write;
318     triggered_ = false;
319 
320     auto mode = base::MessageLoopForIO::WATCH_READ_WRITE;
321     if (watching_read_ && watching_write_)
322       mode = base::MessageLoopForIO::WATCH_READ_WRITE;
323     else if (watching_read_)
324       mode = base::MessageLoopForIO::WATCH_READ;
325     else if (watching_write_)
326       mode = base::MessageLoopForIO::WATCH_WRITE;
327     base::MessageLoopForIO::current()->WatchFileDescriptor(fd_, false, mode,
328                                                            &controller_, this);
329   }
330 
331   // Overrides from base::MessageLoopForIO::Watcher.
OnFileCanReadWithoutBlocking(int)332   void OnFileCanReadWithoutBlocking(int /* fd */) override {
333     triggered_ = true;
334     handler_->ScheduleWork();
335   }
336 
OnFileCanWriteWithoutBlocking(int)337   void OnFileCanWriteWithoutBlocking(int /* fd */) override {
338     triggered_ = true;
339     handler_->ScheduleWork();
340   }
341 
GetFileDescriptor() const342   int GetFileDescriptor() const { return fd_; }
343 
344  private:
345   int fd_{-1};
346   ProtocolHandler* handler_{nullptr};
347   bool watching_read_{false};
348   bool watching_write_{false};
349   bool triggered_{false};
350   base::MessageLoopForIO::FileDescriptorWatcher controller_;
351 
352   DISALLOW_COPY_AND_ASSIGN(Watcher);
353 };
354 
ScheduleWork()355 void ProtocolHandler::ScheduleWork() {
356   if (work_scheduled_)
357     return;
358 
359   work_scheduled_ = true;
360   base::MessageLoopForIO::current()->PostTask(
361       FROM_HERE,
362       base::Bind(&ProtocolHandler::DoWork, weak_ptr_factory_.GetWeakPtr()));
363 }
364 
DoWork()365 void ProtocolHandler::DoWork() {
366   work_scheduled_ = false;
367   weak_ptr_factory_.InvalidateWeakPtrs();
368 
369   // Check if there is any pending work to be done in libmicrohttpd.
370   MHD_run(server_);
371 
372   // Get all the file descriptors from libmicrohttpd and watch for I/O
373   // operations on them.
374   fd_set rs;
375   fd_set ws;
376   fd_set es;
377   int max_fd = MHD_INVALID_SOCKET;
378   FD_ZERO(&rs);
379   FD_ZERO(&ws);
380   FD_ZERO(&es);
381   CHECK_EQ(MHD_YES, MHD_get_fdset(server_, &rs, &ws, &es, &max_fd));
382 
383   for (auto& watcher : watchers_) {
384     int fd = watcher->GetFileDescriptor();
385     if (FD_ISSET(fd, &rs) || FD_ISSET(fd, &ws)) {
386       watcher->Watch(FD_ISSET(fd, &rs), FD_ISSET(fd, &ws));
387       FD_CLR(fd, &rs);
388       FD_CLR(fd, &ws);
389     } else {
390       watcher.reset();
391     }
392   }
393 
394   watchers_.erase(std::remove(watchers_.begin(), watchers_.end(), nullptr),
395                   watchers_.end());
396 
397   for (int fd = 0; fd <= max_fd; fd++) {
398     // libmicrohttpd is not using exception FDs, so lets put our expectations
399     // upfront.
400     CHECK(!FD_ISSET(fd, &es));
401     if (FD_ISSET(fd, &rs) || FD_ISSET(fd, &ws)) {
402       // libmicrohttpd should never use any of stdin/stdout/stderr descriptors.
403       CHECK_GT(fd, STDERR_FILENO);
404       std::unique_ptr<Watcher> watcher{new Watcher{this, fd}};
405       watcher->Watch(FD_ISSET(fd, &rs), FD_ISSET(fd, &ws));
406       watchers_.push_back(std::move(watcher));
407     }
408   }
409 
410   // Schedule a time-out timer, if asked by libmicrohttpd.
411   MHD_UNSIGNED_LONG_LONG mhd_timeout = 0;
412   if (MHD_get_timeout(server_, &mhd_timeout) == MHD_YES) {
413     base::MessageLoopForIO::current()->PostDelayedTask(
414         FROM_HERE,
415         base::Bind(&ProtocolHandler::DoWork, weak_ptr_factory_.GetWeakPtr()),
416         base::TimeDelta::FromMilliseconds(mhd_timeout));
417   }
418 }
419 
420 }  // namespace webservd
421