1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <host/libs/websocket/websocket_server.h>
18
19 #include <string>
20 #include <unordered_map>
21
22 #include <android-base/logging.h>
23 #include <libwebsockets.h>
24
25 #include <common/libs/utils/files.h>
26 #include <host/libs/websocket/websocket_handler.h>
27
28 namespace cuttlefish {
29 namespace {
30
GetPath(struct lws * wsi)31 std::string GetPath(struct lws* wsi) {
32 auto len = lws_hdr_total_length(wsi, WSI_TOKEN_GET_URI);
33 std::string path(len + 1, '\0');
34 auto ret = lws_hdr_copy(wsi, path.data(), path.size(), WSI_TOKEN_GET_URI);
35 if (ret <= 0) {
36 len = lws_hdr_total_length(wsi, WSI_TOKEN_HTTP_COLON_PATH);
37 path.resize(len + 1, '\0');
38 ret =
39 lws_hdr_copy(wsi, path.data(), path.size(), WSI_TOKEN_HTTP_COLON_PATH);
40 }
41 if (ret < 0) {
42 LOG(FATAL) << "Something went wrong getting the path";
43 }
44 path.resize(len);
45 return path;
46 }
47
// CORS headers appended to HTTP responses by AddCORSHeaders(), stored as
// (name, value) pairs. Each name keeps its trailing ':' because it is passed
// verbatim to lws_add_http_header_by_name().
const std::vector<std::pair<std::string, std::string>> kCORSHeaders{
    // "*" allows requests from any origin.
    {"Access-Control-Allow-Origin:", "*"},
    // Methods accepted by the dynamic HTTP handlers.
    {"Access-Control-Allow-Methods:", "POST, GET, OPTIONS"},
    // Request headers a cross-origin client may send.
    {"Access-Control-Allow-Headers:",
     "Content-Type, Access-Control-Allow-Headers, Authorization, "
     "X-Requested-With, Accept"},
};
54
AddCORSHeaders(struct lws * wsi,unsigned char ** buffer_ptr,unsigned char * buffer_end)55 bool AddCORSHeaders(struct lws* wsi, unsigned char** buffer_ptr,
56 unsigned char* buffer_end) {
57 for (const auto& header : kCORSHeaders) {
58 const auto& name = header.first;
59 const auto& value = header.second;
60 if (lws_add_http_header_by_name(
61 wsi, reinterpret_cast<const unsigned char*>(name.c_str()),
62 reinterpret_cast<const unsigned char*>(value.c_str()), value.size(),
63 buffer_ptr, buffer_end)) {
64 return false;
65 }
66 }
67 return true;
68 }
69
WriteCommonHttpHeaders(int status,const char * mime_type,size_t content_len,struct lws * wsi)70 bool WriteCommonHttpHeaders(int status, const char* mime_type,
71 size_t content_len, struct lws* wsi) {
72 constexpr size_t BUFF_SIZE = 2048;
73 uint8_t header_buffer[LWS_PRE + BUFF_SIZE];
74 const auto start = &header_buffer[LWS_PRE];
75 auto p = &header_buffer[LWS_PRE];
76 auto end = start + BUFF_SIZE;
77 if (lws_add_http_common_headers(wsi, status, mime_type, content_len, &p,
78 end)) {
79 LOG(ERROR) << "Failed to write headers for response";
80 return false;
81 }
82 if (!AddCORSHeaders(wsi, &p, end)) {
83 LOG(ERROR) << "Failed to write CORS headers for response";
84 return false;
85 }
86 if (lws_finalize_write_http_header(wsi, start, &p, end)) {
87 LOG(ERROR) << "Failed to finalize headers for response";
88 return false;
89 }
90 return true;
91 }
92
93 } // namespace
WebSocketServer(const char * protocol_name,const std::string & assets_dir,int server_port)94 WebSocketServer::WebSocketServer(const char* protocol_name,
95 const std::string& assets_dir, int server_port)
96 : WebSocketServer(protocol_name, "", assets_dir, server_port) {}
97
WebSocketServer(const char * protocol_name,const std::string & certs_dir,const std::string & assets_dir,int server_port)98 WebSocketServer::WebSocketServer(const char* protocol_name,
99 const std::string& certs_dir,
100 const std::string& assets_dir, int server_port)
101 : protocol_name_(protocol_name),
102 assets_dir_(assets_dir),
103 certs_dir_(certs_dir),
104 server_port_(server_port) {}
105
InitializeLwsObjects()106 void WebSocketServer::InitializeLwsObjects() {
107 std::string cert_file = certs_dir_ + "/server.crt";
108 std::string key_file = certs_dir_ + "/server.key";
109 std::string ca_file = certs_dir_ + "/CA.crt";
110
111 retry_ = {
112 .secs_since_valid_ping = 3,
113 .secs_since_valid_hangup = 10,
114 };
115
116 struct lws_protocols protocols[] = //
117 {{
118 .name = protocol_name_.c_str(),
119 .callback = WebsocketCallback,
120 .per_session_data_size = 0,
121 .rx_buffer_size = 0,
122 .id = 0,
123 .user = this,
124 .tx_packet_size = 0,
125 },
126 {
127 .name = "__http_polling__",
128 .callback = DynHttpCallback,
129 .per_session_data_size = 0,
130 .rx_buffer_size = 0,
131 .id = 0,
132 .user = this,
133 .tx_packet_size = 0,
134 },
135 {
136 .name = nullptr,
137 .callback = nullptr,
138 .per_session_data_size = 0,
139 .rx_buffer_size = 0,
140 .id = 0,
141 .user = nullptr,
142 .tx_packet_size = 0,
143 }};
144
145 dyn_mounts_.reserve(dyn_handler_factories_.size());
146 for (auto& handler_entry : dyn_handler_factories_) {
147 auto& path = handler_entry.first;
148 dyn_mounts_.push_back({
149 .mount_next = nullptr,
150 .mountpoint = path.c_str(),
151 .mountpoint_len = static_cast<uint8_t>(path.size()),
152 .origin = "__http_polling__",
153 .def = nullptr,
154 .protocol = nullptr,
155 .cgienv = nullptr,
156 .extra_mimetypes = nullptr,
157 .interpret = nullptr,
158 .cgi_timeout = 0,
159 .cache_max_age = 0,
160 .auth_mask = 0,
161 .cache_reusable = 0,
162 .cache_revalidate = 0,
163 .cache_intermediaries = 0,
164 .origin_protocol = LWSMPRO_CALLBACK, // dynamic
165 .basic_auth_login_file = nullptr,
166 });
167 }
168 struct lws_http_mount* next_mount = nullptr;
169 // Set up the linked list after all the mounts have been created to ensure
170 // pointers are not invalidated.
171 for (auto& mount : dyn_mounts_) {
172 mount.mount_next = next_mount;
173 next_mount = &mount;
174 }
175
176 static_mount_ = {
177 .mount_next = next_mount,
178 .mountpoint = "/",
179 .mountpoint_len = 1,
180 .origin = assets_dir_.c_str(),
181 .def = "index.html",
182 .protocol = nullptr,
183 .cgienv = nullptr,
184 .extra_mimetypes = nullptr,
185 .interpret = nullptr,
186 .cgi_timeout = 0,
187 .cache_max_age = 0,
188 .auth_mask = 0,
189 .cache_reusable = 0,
190 .cache_revalidate = 0,
191 .cache_intermediaries = 0,
192 .origin_protocol = LWSMPRO_FILE, // files in a dir
193 .basic_auth_login_file = nullptr,
194 };
195
196 struct lws_context_creation_info info;
197 headers_ = {NULL, NULL, "content-security-policy:",
198 "default-src 'self' https://ajax.googleapis.com; "
199 "style-src 'self' https://fonts.googleapis.com/; "
200 "font-src https://fonts.gstatic.com/; "};
201
202 memset(&info, 0, sizeof info);
203 info.port = server_port_;
204 info.mounts = &static_mount_;
205 info.protocols = protocols;
206 info.vhost_name = "localhost";
207 info.headers = &headers_;
208 info.retry_and_idle_policy = &retry_;
209
210 if (!certs_dir_.empty()) {
211 info.options |= LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
212 info.ssl_cert_filepath = cert_file.c_str();
213 info.ssl_private_key_filepath = key_file.c_str();
214 if (FileExists(ca_file)) {
215 info.ssl_ca_filepath = ca_file.c_str();
216 }
217 }
218
219 context_ = lws_create_context(&info);
220 if (!context_) {
221 LOG(FATAL) << "Failed to create websocket context";
222 }
223 }
224
RegisterHandlerFactory(const std::string & path,std::unique_ptr<WebSocketHandlerFactory> handler_factory_p)225 void WebSocketServer::RegisterHandlerFactory(
226 const std::string& path,
227 std::unique_ptr<WebSocketHandlerFactory> handler_factory_p) {
228 handler_factories_[path] = std::move(handler_factory_p);
229 }
230
RegisterDynHandlerFactory(const std::string & path,DynHandlerFactory handler_factory)231 void WebSocketServer::RegisterDynHandlerFactory(
232 const std::string& path,
233 DynHandlerFactory handler_factory) {
234 dyn_handler_factories_[path] = std::move(handler_factory);
235 }
236
Serve()237 void WebSocketServer::Serve() {
238 InitializeLwsObjects();
239 int n = 0;
240 while (n >= 0) {
241 n = lws_service(context_, 0);
242 }
243 lws_context_destroy(context_);
244 }
245
WebsocketCallback(struct lws * wsi,enum lws_callback_reasons reason,void * user,void * in,size_t len)246 int WebSocketServer::WebsocketCallback(struct lws* wsi,
247 enum lws_callback_reasons reason,
248 void* user, void* in, size_t len) {
249 auto protocol = lws_get_protocol(wsi);
250 if (!protocol) {
251 // Some callback reasons are always handled by the first protocol, before a
252 // wsi struct is even created.
253 return lws_callback_http_dummy(wsi, reason, user, in, len);
254 }
255 return reinterpret_cast<WebSocketServer*>(protocol->user)
256 ->ServerCallback(wsi, reason, user, in, len);
257 }
258
DynHttpCallback(struct lws * wsi,enum lws_callback_reasons reason,void * user,void * in,size_t len)259 int WebSocketServer::DynHttpCallback(struct lws* wsi,
260 enum lws_callback_reasons reason,
261 void* user, void* in, size_t len) {
262 auto protocol = lws_get_protocol(wsi);
263 if (!protocol) {
264 LOG(ERROR) << "No protocol associated with connection";
265 return 1;
266 }
267 return reinterpret_cast<WebSocketServer*>(protocol->user)
268 ->DynServerCallback(wsi, reason, user, in, len);
269 }
270
DynServerCallback(struct lws * wsi,enum lws_callback_reasons reason,void * user,void * in,size_t len)271 int WebSocketServer::DynServerCallback(struct lws* wsi,
272 enum lws_callback_reasons reason,
273 void* user, void* in, size_t len) {
274 switch (reason) {
275 case LWS_CALLBACK_HTTP: {
276 char* path_raw;
277 int path_len;
278 auto method = lws_http_get_uri_and_method(wsi, &path_raw, &path_len);
279 if (method < 0) {
280 return 1;
281 }
282 std::string path(path_raw, path_len);
283 auto handler = InstantiateDynHandler(path, wsi);
284 if (!handler) {
285 if (!WriteCommonHttpHeaders(static_cast<int>(HttpStatusCode::NotFound),
286 "application/json", 0, wsi)) {
287 return 1;
288 }
289 return lws_http_transaction_completed(wsi);
290 }
291 dyn_handlers_[wsi] = std::move(handler);
292 switch (method) {
293 case LWSHUMETH_GET: {
294 auto status = dyn_handlers_[wsi]->DoGet();
295 if (!WriteCommonHttpHeaders(static_cast<int>(status),
296 "application/json",
297 dyn_handlers_[wsi]->content_len(), wsi)) {
298 return 1;
299 }
300 // Write the response later, when the server is ready
301 lws_callback_on_writable(wsi);
302 break;
303 }
304 case LWSHUMETH_POST:
305 // Do nothing until the body has been read
306 break;
307 case LWSHUMETH_OPTIONS: {
308 // Response for CORS preflight
309 auto status = HttpStatusCode::NoContent;
310 if (!WriteCommonHttpHeaders(static_cast<int>(status), "", 0, wsi)) {
311 return 1;
312 }
313 lws_callback_on_writable(wsi);
314 break;
315 }
316 default:
317 LOG(ERROR) << "Unsupported HTTP method: " << method;
318 return 1;
319 }
320 break;
321 }
322 case LWS_CALLBACK_HTTP_BODY: {
323 auto handler = dyn_handlers_[wsi].get();
324 if (!handler) {
325 LOG(WARNING) << "Received body for unknown wsi";
326 return 1;
327 }
328 handler->AppendDataIn(in, len);
329 break;
330 }
331 case LWS_CALLBACK_HTTP_BODY_COMPLETION: {
332 auto handler = dyn_handlers_[wsi].get();
333 if (!handler) {
334 LOG(WARNING) << "Unexpected body completion event from unknown wsi";
335 return 1;
336 }
337 auto status = handler->DoPost();
338 if (!WriteCommonHttpHeaders(static_cast<int>(status), "application/json",
339 dyn_handlers_[wsi]->content_len(), wsi)) {
340 return 1;
341 }
342 lws_callback_on_writable(wsi);
343 break;
344 }
345 case LWS_CALLBACK_HTTP_WRITEABLE: {
346 auto handler = dyn_handlers_[wsi].get();
347 if (!handler) {
348 LOG(WARNING) << "Unknown wsi became writable";
349 return 1;
350 }
351 auto ret = handler->OnWritable();
352 dyn_handlers_.erase(wsi);
353 // Make sure the connection (in HTTP 1) or stream (in HTTP 2) is closed
354 // after the response is written
355 return ret;
356 }
357 case LWS_CALLBACK_CLOSED_HTTP:
358 break;
359 default:
360 return lws_callback_http_dummy(wsi, reason, user, in, len);
361 }
362 return 0;
363 }
364
ServerCallback(struct lws * wsi,enum lws_callback_reasons reason,void * user,void * in,size_t len)365 int WebSocketServer::ServerCallback(struct lws* wsi,
366 enum lws_callback_reasons reason,
367 void* user, void* in, size_t len) {
368 switch (reason) {
369 case LWS_CALLBACK_ESTABLISHED: {
370 auto path = GetPath(wsi);
371 auto handler = InstantiateHandler(path, wsi);
372 if (!handler) {
373 // This message came on an unexpected uri, close the connection.
374 lws_close_reason(wsi, LWS_CLOSE_STATUS_NOSTATUS, (uint8_t*)"404", 3);
375 return -1;
376 }
377 handlers_[wsi] = handler;
378 handler->OnConnected();
379 break;
380 }
381 case LWS_CALLBACK_CLOSED: {
382 auto handler = handlers_[wsi];
383 if (handler) {
384 handler->OnClosed();
385 handlers_.erase(wsi);
386 }
387 break;
388 }
389 case LWS_CALLBACK_SERVER_WRITEABLE: {
390 auto handler = handlers_[wsi];
391 if (handler) {
392 auto should_close = handler->OnWritable();
393 if (should_close) {
394 lws_close_reason(wsi, LWS_CLOSE_STATUS_NORMAL, nullptr, 0);
395 return 1;
396 }
397 } else {
398 LOG(WARNING) << "Unknown wsi became writable";
399 return -1;
400 }
401 break;
402 }
403 case LWS_CALLBACK_RECEIVE: {
404 auto handler = handlers_[wsi];
405 if (handler) {
406 bool is_final = (lws_remaining_packet_payload(wsi) == 0) &&
407 lws_is_final_fragment(wsi);
408 handler->OnReceive(reinterpret_cast<const uint8_t*>(in), len,
409 lws_frame_is_binary(wsi), is_final);
410 } else {
411 LOG(WARNING) << "Unknown wsi sent data";
412 }
413 break;
414 }
415 default:
416 return lws_callback_http_dummy(wsi, reason, user, in, len);
417 }
418 return 0;
419 }
420
InstantiateHandler(const std::string & uri_path,struct lws * wsi)421 std::shared_ptr<WebSocketHandler> WebSocketServer::InstantiateHandler(
422 const std::string& uri_path, struct lws* wsi) {
423 auto it = handler_factories_.find(uri_path);
424 if (it == handler_factories_.end()) {
425 LOG(ERROR) << "Wrong path provided in URI: " << uri_path;
426 return nullptr;
427 } else {
428 LOG(VERBOSE) << "Creating handler for " << uri_path;
429 return it->second->Build(wsi);
430 }
431 }
432
InstantiateDynHandler(const std::string & uri_path,struct lws * wsi)433 std::unique_ptr<DynHandler> WebSocketServer::InstantiateDynHandler(
434 const std::string& uri_path, struct lws* wsi) {
435 auto it = dyn_handler_factories_.find(uri_path);
436 if (it == dyn_handler_factories_.end()) {
437 LOG(ERROR) << "Wrong path provided in URI: " << uri_path;
438 return nullptr;
439 } else {
440 LOG(VERBOSE) << "Creating handler for " << uri_path;
441 return it->second(wsi);
442 }
443 }
444
445 } // namespace cuttlefish
446