1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <host/libs/websocket/websocket_server.h>
18 
19 #include <string>
20 #include <unordered_map>
21 
22 #include <android-base/logging.h>
23 #include <libwebsockets.h>
24 
25 #include <common/libs/utils/files.h>
26 #include <host/libs/websocket/websocket_handler.h>
27 
28 namespace cuttlefish {
29 namespace {
30 
GetPath(struct lws * wsi)31 std::string GetPath(struct lws* wsi) {
32   auto len = lws_hdr_total_length(wsi, WSI_TOKEN_GET_URI);
33   std::string path(len + 1, '\0');
34   auto ret = lws_hdr_copy(wsi, path.data(), path.size(), WSI_TOKEN_GET_URI);
35   if (ret <= 0) {
36     len = lws_hdr_total_length(wsi, WSI_TOKEN_HTTP_COLON_PATH);
37     path.resize(len + 1, '\0');
38     ret =
39         lws_hdr_copy(wsi, path.data(), path.size(), WSI_TOKEN_HTTP_COLON_PATH);
40   }
41   if (ret < 0) {
42     LOG(FATAL) << "Something went wrong getting the path";
43   }
44   path.resize(len);
45   return path;
46 }
47 
// (name, value) pairs of the CORS headers attached to HTTP responses. Each
// name carries its trailing ':' because it is handed to
// lws_add_http_header_by_name() verbatim.
const std::vector<std::pair<std::string, std::string>> kCORSHeaders{
    std::pair<std::string, std::string>{"Access-Control-Allow-Origin:", "*"},
    std::pair<std::string, std::string>{"Access-Control-Allow-Methods:",
                                        "POST, GET, OPTIONS"},
    std::pair<std::string, std::string>{
        "Access-Control-Allow-Headers:",
        "Content-Type, Access-Control-Allow-Headers, Authorization, "
        "X-Requested-With, Accept"},
};
54 
AddCORSHeaders(struct lws * wsi,unsigned char ** buffer_ptr,unsigned char * buffer_end)55 bool AddCORSHeaders(struct lws* wsi, unsigned char** buffer_ptr,
56                     unsigned char* buffer_end) {
57   for (const auto& header : kCORSHeaders) {
58     const auto& name = header.first;
59     const auto& value = header.second;
60     if (lws_add_http_header_by_name(
61             wsi, reinterpret_cast<const unsigned char*>(name.c_str()),
62             reinterpret_cast<const unsigned char*>(value.c_str()), value.size(),
63             buffer_ptr, buffer_end)) {
64       return false;
65     }
66   }
67   return true;
68 }
69 
WriteCommonHttpHeaders(int status,const char * mime_type,size_t content_len,struct lws * wsi)70 bool WriteCommonHttpHeaders(int status, const char* mime_type,
71                             size_t content_len, struct lws* wsi) {
72   constexpr size_t BUFF_SIZE = 2048;
73   uint8_t header_buffer[LWS_PRE + BUFF_SIZE];
74   const auto start = &header_buffer[LWS_PRE];
75   auto p = &header_buffer[LWS_PRE];
76   auto end = start + BUFF_SIZE;
77   if (lws_add_http_common_headers(wsi, status, mime_type, content_len, &p,
78                                   end)) {
79     LOG(ERROR) << "Failed to write headers for response";
80     return false;
81   }
82   if (!AddCORSHeaders(wsi, &p, end)) {
83     LOG(ERROR) << "Failed to write CORS headers for response";
84     return false;
85   }
86   if (lws_finalize_write_http_header(wsi, start, &p, end)) {
87     LOG(ERROR) << "Failed to finalize headers for response";
88     return false;
89   }
90   return true;
91 }
92 
93 }  // namespace
WebSocketServer(const char * protocol_name,const std::string & assets_dir,int server_port)94 WebSocketServer::WebSocketServer(const char* protocol_name,
95                                  const std::string& assets_dir, int server_port)
96     : WebSocketServer(protocol_name, "", assets_dir, server_port) {}
97 
WebSocketServer(const char * protocol_name,const std::string & certs_dir,const std::string & assets_dir,int server_port)98 WebSocketServer::WebSocketServer(const char* protocol_name,
99                                  const std::string& certs_dir,
100                                  const std::string& assets_dir, int server_port)
101     : protocol_name_(protocol_name),
102       assets_dir_(assets_dir),
103       certs_dir_(certs_dir),
104       server_port_(server_port) {}
105 
InitializeLwsObjects()106 void WebSocketServer::InitializeLwsObjects() {
107   std::string cert_file = certs_dir_ + "/server.crt";
108   std::string key_file = certs_dir_ + "/server.key";
109   std::string ca_file = certs_dir_ + "/CA.crt";
110 
111   retry_ = {
112       .secs_since_valid_ping = 3,
113       .secs_since_valid_hangup = 10,
114   };
115 
116   struct lws_protocols protocols[] =  //
117       {{
118            .name = protocol_name_.c_str(),
119            .callback = WebsocketCallback,
120            .per_session_data_size = 0,
121            .rx_buffer_size = 0,
122            .id = 0,
123            .user = this,
124            .tx_packet_size = 0,
125        },
126        {
127            .name = "__http_polling__",
128            .callback = DynHttpCallback,
129            .per_session_data_size = 0,
130            .rx_buffer_size = 0,
131            .id = 0,
132            .user = this,
133            .tx_packet_size = 0,
134        },
135        {
136            .name = nullptr,
137            .callback = nullptr,
138            .per_session_data_size = 0,
139            .rx_buffer_size = 0,
140            .id = 0,
141            .user = nullptr,
142            .tx_packet_size = 0,
143        }};
144 
145   dyn_mounts_.reserve(dyn_handler_factories_.size());
146   for (auto& handler_entry : dyn_handler_factories_) {
147     auto& path = handler_entry.first;
148     dyn_mounts_.push_back({
149         .mount_next = nullptr,
150         .mountpoint = path.c_str(),
151         .mountpoint_len = static_cast<uint8_t>(path.size()),
152         .origin = "__http_polling__",
153         .def = nullptr,
154         .protocol = nullptr,
155         .cgienv = nullptr,
156         .extra_mimetypes = nullptr,
157         .interpret = nullptr,
158         .cgi_timeout = 0,
159         .cache_max_age = 0,
160         .auth_mask = 0,
161         .cache_reusable = 0,
162         .cache_revalidate = 0,
163         .cache_intermediaries = 0,
164         .origin_protocol = LWSMPRO_CALLBACK,  // dynamic
165         .basic_auth_login_file = nullptr,
166     });
167   }
168   struct lws_http_mount* next_mount = nullptr;
169   // Set up the linked list after all the mounts have been created to ensure
170   // pointers are not invalidated.
171   for (auto& mount : dyn_mounts_) {
172     mount.mount_next = next_mount;
173     next_mount = &mount;
174   }
175 
176   static_mount_ = {
177       .mount_next = next_mount,
178       .mountpoint = "/",
179       .mountpoint_len = 1,
180       .origin = assets_dir_.c_str(),
181       .def = "index.html",
182       .protocol = nullptr,
183       .cgienv = nullptr,
184       .extra_mimetypes = nullptr,
185       .interpret = nullptr,
186       .cgi_timeout = 0,
187       .cache_max_age = 0,
188       .auth_mask = 0,
189       .cache_reusable = 0,
190       .cache_revalidate = 0,
191       .cache_intermediaries = 0,
192       .origin_protocol = LWSMPRO_FILE,  // files in a dir
193       .basic_auth_login_file = nullptr,
194   };
195 
196   struct lws_context_creation_info info;
197   headers_ = {NULL, NULL, "content-security-policy:",
198               "default-src 'self' https://ajax.googleapis.com; "
199               "style-src 'self' https://fonts.googleapis.com/; "
200               "font-src  https://fonts.gstatic.com/; "};
201 
202   memset(&info, 0, sizeof info);
203   info.port = server_port_;
204   info.mounts = &static_mount_;
205   info.protocols = protocols;
206   info.vhost_name = "localhost";
207   info.headers = &headers_;
208   info.retry_and_idle_policy = &retry_;
209 
210   if (!certs_dir_.empty()) {
211     info.options |= LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
212     info.ssl_cert_filepath = cert_file.c_str();
213     info.ssl_private_key_filepath = key_file.c_str();
214     if (FileExists(ca_file)) {
215       info.ssl_ca_filepath = ca_file.c_str();
216     }
217   }
218 
219   context_ = lws_create_context(&info);
220   if (!context_) {
221     LOG(FATAL) << "Failed to create websocket context";
222   }
223 }
224 
RegisterHandlerFactory(const std::string & path,std::unique_ptr<WebSocketHandlerFactory> handler_factory_p)225 void WebSocketServer::RegisterHandlerFactory(
226     const std::string& path,
227     std::unique_ptr<WebSocketHandlerFactory> handler_factory_p) {
228   handler_factories_[path] = std::move(handler_factory_p);
229 }
230 
RegisterDynHandlerFactory(const std::string & path,DynHandlerFactory handler_factory)231 void WebSocketServer::RegisterDynHandlerFactory(
232     const std::string& path,
233     DynHandlerFactory handler_factory) {
234   dyn_handler_factories_[path] = std::move(handler_factory);
235 }
236 
Serve()237 void WebSocketServer::Serve() {
238   InitializeLwsObjects();
239   int n = 0;
240   while (n >= 0) {
241     n = lws_service(context_, 0);
242   }
243   lws_context_destroy(context_);
244 }
245 
WebsocketCallback(struct lws * wsi,enum lws_callback_reasons reason,void * user,void * in,size_t len)246 int WebSocketServer::WebsocketCallback(struct lws* wsi,
247                                        enum lws_callback_reasons reason,
248                                        void* user, void* in, size_t len) {
249   auto protocol = lws_get_protocol(wsi);
250   if (!protocol) {
251     // Some callback reasons are always handled by the first protocol, before a
252     // wsi struct is even created.
253     return lws_callback_http_dummy(wsi, reason, user, in, len);
254   }
255   return reinterpret_cast<WebSocketServer*>(protocol->user)
256       ->ServerCallback(wsi, reason, user, in, len);
257 }
258 
DynHttpCallback(struct lws * wsi,enum lws_callback_reasons reason,void * user,void * in,size_t len)259 int WebSocketServer::DynHttpCallback(struct lws* wsi,
260                                      enum lws_callback_reasons reason,
261                                      void* user, void* in, size_t len) {
262   auto protocol = lws_get_protocol(wsi);
263   if (!protocol) {
264     LOG(ERROR) << "No protocol associated with connection";
265     return 1;
266   }
267   return reinterpret_cast<WebSocketServer*>(protocol->user)
268       ->DynServerCallback(wsi, reason, user, in, len);
269 }
270 
DynServerCallback(struct lws * wsi,enum lws_callback_reasons reason,void * user,void * in,size_t len)271 int WebSocketServer::DynServerCallback(struct lws* wsi,
272                                        enum lws_callback_reasons reason,
273                                        void* user, void* in, size_t len) {
274   switch (reason) {
275     case LWS_CALLBACK_HTTP: {
276       char* path_raw;
277       int path_len;
278       auto method = lws_http_get_uri_and_method(wsi, &path_raw, &path_len);
279       if (method < 0) {
280         return 1;
281       }
282       std::string path(path_raw, path_len);
283       auto handler = InstantiateDynHandler(path, wsi);
284       if (!handler) {
285         if (!WriteCommonHttpHeaders(static_cast<int>(HttpStatusCode::NotFound),
286                                     "application/json", 0, wsi)) {
287           return 1;
288         }
289         return lws_http_transaction_completed(wsi);
290       }
291       dyn_handlers_[wsi] = std::move(handler);
292       switch (method) {
293         case LWSHUMETH_GET: {
294           auto status = dyn_handlers_[wsi]->DoGet();
295           if (!WriteCommonHttpHeaders(static_cast<int>(status),
296                                       "application/json",
297                                       dyn_handlers_[wsi]->content_len(), wsi)) {
298             return 1;
299           }
300           // Write the response later, when the server is ready
301           lws_callback_on_writable(wsi);
302           break;
303         }
304         case LWSHUMETH_POST:
305           // Do nothing until the body has been read
306           break;
307         case LWSHUMETH_OPTIONS: {
308           // Response for CORS preflight
309           auto status = HttpStatusCode::NoContent;
310           if (!WriteCommonHttpHeaders(static_cast<int>(status), "", 0, wsi)) {
311             return 1;
312           }
313           lws_callback_on_writable(wsi);
314           break;
315         }
316         default:
317           LOG(ERROR) << "Unsupported HTTP method: " << method;
318           return 1;
319       }
320       break;
321     }
322     case LWS_CALLBACK_HTTP_BODY: {
323       auto handler = dyn_handlers_[wsi].get();
324       if (!handler) {
325         LOG(WARNING) << "Received body for unknown wsi";
326         return 1;
327       }
328       handler->AppendDataIn(in, len);
329       break;
330     }
331     case LWS_CALLBACK_HTTP_BODY_COMPLETION: {
332       auto handler = dyn_handlers_[wsi].get();
333       if (!handler) {
334         LOG(WARNING) << "Unexpected body completion event from unknown wsi";
335         return 1;
336       }
337       auto status = handler->DoPost();
338       if (!WriteCommonHttpHeaders(static_cast<int>(status), "application/json",
339                                   dyn_handlers_[wsi]->content_len(), wsi)) {
340         return 1;
341       }
342       lws_callback_on_writable(wsi);
343       break;
344     }
345     case LWS_CALLBACK_HTTP_WRITEABLE: {
346       auto handler = dyn_handlers_[wsi].get();
347       if (!handler) {
348         LOG(WARNING) << "Unknown wsi became writable";
349         return 1;
350       }
351       auto ret = handler->OnWritable();
352       dyn_handlers_.erase(wsi);
353       // Make sure the connection (in HTTP 1) or stream (in HTTP 2) is closed
354       // after the response is written
355       return ret;
356     }
357     case LWS_CALLBACK_CLOSED_HTTP:
358       break;
359     default:
360       return lws_callback_http_dummy(wsi, reason, user, in, len);
361   }
362   return 0;
363 }
364 
ServerCallback(struct lws * wsi,enum lws_callback_reasons reason,void * user,void * in,size_t len)365 int WebSocketServer::ServerCallback(struct lws* wsi,
366                                     enum lws_callback_reasons reason,
367                                     void* user, void* in, size_t len) {
368   switch (reason) {
369     case LWS_CALLBACK_ESTABLISHED: {
370       auto path = GetPath(wsi);
371       auto handler = InstantiateHandler(path, wsi);
372       if (!handler) {
373         // This message came on an unexpected uri, close the connection.
374         lws_close_reason(wsi, LWS_CLOSE_STATUS_NOSTATUS, (uint8_t*)"404", 3);
375         return -1;
376       }
377       handlers_[wsi] = handler;
378       handler->OnConnected();
379       break;
380     }
381     case LWS_CALLBACK_CLOSED: {
382       auto handler = handlers_[wsi];
383       if (handler) {
384         handler->OnClosed();
385         handlers_.erase(wsi);
386       }
387       break;
388     }
389     case LWS_CALLBACK_SERVER_WRITEABLE: {
390       auto handler = handlers_[wsi];
391       if (handler) {
392         auto should_close = handler->OnWritable();
393         if (should_close) {
394           lws_close_reason(wsi, LWS_CLOSE_STATUS_NORMAL, nullptr, 0);
395           return 1;
396         }
397       } else {
398         LOG(WARNING) << "Unknown wsi became writable";
399         return -1;
400       }
401       break;
402     }
403     case LWS_CALLBACK_RECEIVE: {
404       auto handler = handlers_[wsi];
405       if (handler) {
406         bool is_final = (lws_remaining_packet_payload(wsi) == 0) &&
407                         lws_is_final_fragment(wsi);
408         handler->OnReceive(reinterpret_cast<const uint8_t*>(in), len,
409                            lws_frame_is_binary(wsi), is_final);
410       } else {
411         LOG(WARNING) << "Unknown wsi sent data";
412       }
413       break;
414     }
415     default:
416       return lws_callback_http_dummy(wsi, reason, user, in, len);
417   }
418   return 0;
419 }
420 
InstantiateHandler(const std::string & uri_path,struct lws * wsi)421 std::shared_ptr<WebSocketHandler> WebSocketServer::InstantiateHandler(
422     const std::string& uri_path, struct lws* wsi) {
423   auto it = handler_factories_.find(uri_path);
424   if (it == handler_factories_.end()) {
425     LOG(ERROR) << "Wrong path provided in URI: " << uri_path;
426     return nullptr;
427   } else {
428     LOG(VERBOSE) << "Creating handler for " << uri_path;
429     return it->second->Build(wsi);
430   }
431 }
432 
InstantiateDynHandler(const std::string & uri_path,struct lws * wsi)433 std::unique_ptr<DynHandler> WebSocketServer::InstantiateDynHandler(
434     const std::string& uri_path, struct lws* wsi) {
435   auto it = dyn_handler_factories_.find(uri_path);
436   if (it == dyn_handler_factories_.end()) {
437     LOG(ERROR) << "Wrong path provided in URI: " << uri_path;
438     return nullptr;
439   } else {
440     LOG(VERBOSE) << "Creating handler for " << uri_path;
441     return it->second(wsi);
442   }
443 }
444 
445 }  // namespace cuttlefish
446