// // Copyright (C) 2020 The Android Open Source Project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // #include "host/frontend/webrtc/libdevice/server_connection.h" #include #include #include "common/libs/fs/shared_fd.h" #include "common/libs/fs/shared_select.h" #include "common/libs/utils/files.h" namespace cuttlefish { namespace webrtc_streaming { // ServerConnection over Unix socket class UnixServerConnection : public ServerConnection { public: UnixServerConnection(const std::string& addr, std::weak_ptr observer); ~UnixServerConnection() override; bool Send(const Json::Value& msg) override; private: void Connect() override; void StopThread(); void ReadLoop(); const std::string addr_; SharedFD conn_; std::mutex write_mtx_; std::weak_ptr observer_; // The event fd must be declared before the thread to ensure it's initialized // before the thread starts and is safe to be accessed from it. SharedFD thread_notifier_; std::atomic_bool running_ = false; std::thread thread_; }; // ServerConnection using websockets class WsConnectionContext; class WsConnection : public std::enable_shared_from_this { public: struct CreateConnectionSul { lws_sorted_usec_list_t sul = {}; std::weak_ptr weak_this; }; WsConnection(int port, const std::string& addr, const std::string& path, ServerConfig::Security secure, std::weak_ptr observer, std::shared_ptr context); ~WsConnection(); void Connect(); bool Send(const Json::Value& msg); void ConnectInner(); void OnError(const std::string& error); void OnReceive(const uint8_t* data, size_t len, bool is_binary); void OnOpen(); void OnClose(); void OnWriteable(); private: struct WsBuffer { WsBuffer() = default; WsBuffer(const uint8_t* data, size_t len, bool binary) : buffer_(LWS_PRE + len), is_binary_(binary) { memcpy(&buffer_[LWS_PRE], data, len); } uint8_t* data() { return &buffer_[LWS_PRE]; } bool is_binary() const { return is_binary_; } size_t size() const { return buffer_.size() - LWS_PRE; } private: std::vector buffer_; bool is_binary_; }; bool Send(const uint8_t* data, size_t len, bool binary = false); CreateConnectionSul extended_sul_; struct lws* wsi_; const int port_; const std::string addr_; const std::string path_; const ServerConfig::Security security_; std::weak_ptr observer_; // each element contains the data to be sent and whether it's binary or not std::deque write_queue_; std::mutex write_queue_mutex_; // The connection object should not outlive the context object. This reference // guarantees it. std::shared_ptr context_; }; class WsConnectionContext : public std::enable_shared_from_this { public: static std::shared_ptr Create(); WsConnectionContext(struct lws_context* lws_ctx); ~WsConnectionContext(); std::unique_ptr CreateConnection( int port, const std::string& addr, const std::string& path, ServerConfig::Security secure, std::weak_ptr observer); void RememberConnection(void*, std::weak_ptr); void ForgetConnection(void*); std::shared_ptr GetConnection(void*); struct lws_context* lws_context() { return lws_context_; } private: void Start(); std::map> weak_by_ptr_; std::mutex map_mutex_; struct lws_context* lws_context_; std::thread message_loop_; }; std::unique_ptr ServerConnection::Connect( const ServerConfig& conf, std::weak_ptr observer) { std::unique_ptr ret; // If the provided address points to an existing UNIX socket in the file // system connect to it, otherwise assume it's a network address and connect // using websockets if (FileIsSocket(conf.addr)) { ret.reset(new UnixServerConnection(conf.addr, observer)); } else { // This can be a local variable since the ws connection will keep a // reference to it. auto ws_context = WsConnectionContext::Create(); CHECK(ws_context) << "Failed to create websocket context"; ret = ws_context->CreateConnection(conf.port, conf.addr, conf.path, conf.security, observer); } ret->Connect(); return ret; } void ServerConnection::Reconnect() { Connect(); } // UnixServerConnection implementation UnixServerConnection::UnixServerConnection( const std::string& addr, std::weak_ptr observer) : addr_(addr), observer_(observer) {} UnixServerConnection::~UnixServerConnection() { StopThread(); } bool UnixServerConnection::Send(const Json::Value& msg) { Json::StreamWriterBuilder factory; auto str = Json::writeString(factory, msg); std::lock_guard lock(write_mtx_); auto res = conn_->Send(reinterpret_cast(str.c_str()), str.size(), 0); if (res < 0) { LOG(ERROR) << "Failed to send data to signaling server: " << conn_->StrError(); // Don't call OnError() here, the receiving thread probably did it already // or is about to do it. } // A SOCK_SEQPACKET unix socket will send the entire message or fail, but it // won't send a partial message. return res == str.size(); } void UnixServerConnection::Connect() { // The thread could be running if this is a Reconnect StopThread(); conn_ = SharedFD::SocketLocalClient(addr_, false, SOCK_SEQPACKET); if (!conn_->IsOpen()) { LOG(ERROR) << "Failed to connect to unix socket: " << conn_->StrError(); if (auto o = observer_.lock(); o) { o->OnError("Failed to connect to unix socket"); } return; } thread_notifier_ = SharedFD::Event(); if (!thread_notifier_->IsOpen()) { LOG(ERROR) << "Failed to create eventfd for background thread: " << thread_notifier_->StrError(); if (auto o = observer_.lock(); o) { o->OnError("Failed to create eventfd for background thread"); } return; } if (auto o = observer_.lock(); o) { o->OnOpen(); } // Start the thread running_ = true; thread_ = std::thread([this](){ReadLoop();}); } void UnixServerConnection::StopThread() { running_ = false; if (!thread_notifier_->IsOpen()) { // The thread won't be running if this isn't open return; } if (thread_notifier_->EventfdWrite(1) < 0) { LOG(ERROR) << "Failed to notify background thread, this thread may block"; } if (thread_.joinable()) { thread_.join(); } } void UnixServerConnection::ReadLoop() { if (!thread_notifier_->IsOpen()) { LOG(ERROR) << "The UnixServerConnection's background thread is unable to " "receive notifications so it can't run"; return; } std::vector buffer(4096, 0); while (running_) { SharedFDSet rset; rset.Set(thread_notifier_); rset.Set(conn_); auto res = Select(&rset, nullptr, nullptr, nullptr); if (res < 0) { LOG(ERROR) << "Failed to select from background thread"; break; } if (rset.IsSet(thread_notifier_)) { eventfd_t val; auto res = thread_notifier_->EventfdRead(&val); if (res < 0) { LOG(ERROR) << "Error reading from event fd: " << thread_notifier_->StrError(); break; } } if (rset.IsSet(conn_)) { auto size = conn_->Recv(buffer.data(), 0, MSG_TRUNC | MSG_PEEK); if (size > buffer.size()) { // Enlarge enough to accommodate size bytes and be a multiple of 4096 auto new_size = (size + 4095) & ~4095; buffer.resize(new_size); } auto res = conn_->Recv(buffer.data(), buffer.size(), MSG_TRUNC); if (res < 0) { LOG(ERROR) << "Failed to read from server: " << conn_->StrError(); if (auto observer = observer_.lock(); observer) { observer->OnError(conn_->StrError()); } return; } if (res == 0) { auto observer = observer_.lock(); if (observer) { observer->OnClose(); } break; } auto observer = observer_.lock(); if (observer) { observer->OnReceive(buffer.data(), res, false); } } } } // WsConnection implementation int LwsCallback(struct lws* wsi, enum lws_callback_reasons reason, void* user, void* in, size_t len); void CreateConnectionCallback(lws_sorted_usec_list_t* sul); namespace { constexpr char kProtocolName[] = "cf-webrtc-device"; constexpr int kBufferSize = 65536; const uint32_t backoff_ms[] = {1000, 2000, 3000, 4000, 5000}; const lws_retry_bo_t kRetry = { .retry_ms_table = backoff_ms, .retry_ms_table_count = LWS_ARRAY_SIZE(backoff_ms), .conceal_count = LWS_ARRAY_SIZE(backoff_ms), .secs_since_valid_ping = 3, /* force PINGs after secs idle */ .secs_since_valid_hangup = 10, /* hangup after secs idle */ .jitter_percent = 20, }; const struct lws_protocols kProtocols[2] = { {kProtocolName, LwsCallback, 0, kBufferSize, 0, NULL, 0}, {NULL, NULL, 0, 0, 0, NULL, 0}}; } // namespace std::shared_ptr WsConnectionContext::Create() { struct lws_context_creation_info context_info = {}; context_info.port = CONTEXT_PORT_NO_LISTEN; context_info.options = LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT; context_info.protocols = kProtocols; struct lws_context* lws_ctx = lws_create_context(&context_info); if (!lws_ctx) { return nullptr; } return std::shared_ptr(new WsConnectionContext(lws_ctx)); } WsConnectionContext::WsConnectionContext(struct lws_context* lws_ctx) : lws_context_(lws_ctx) { Start(); } WsConnectionContext::~WsConnectionContext() { lws_context_destroy(lws_context_); if (message_loop_.joinable()) { message_loop_.join(); } } void WsConnectionContext::Start() { message_loop_ = std::thread([this]() { for (;;) { if (lws_service(lws_context_, 0) < 0) { break; } } }); } // This wrapper is needed because the ServerConnection objects are meant to be // referenced by std::unique_ptr but WsConnection needs to be referenced by // std::shared_ptr because it's also (weakly) referenced by the websocket // thread. class WsConnectionWrapper : public ServerConnection { public: WsConnectionWrapper(std::shared_ptr conn) : conn_(conn) {} bool Send(const Json::Value& msg) override { return conn_->Send(msg); } private: void Connect() override { return conn_->Connect(); } std::shared_ptr conn_; }; std::unique_ptr WsConnectionContext::CreateConnection( int port, const std::string& addr, const std::string& path, ServerConfig::Security security, std::weak_ptr observer) { return std::unique_ptr( new WsConnectionWrapper(std::make_shared( port, addr, path, security, observer, shared_from_this()))); } std::shared_ptr WsConnectionContext::GetConnection(void* raw) { std::shared_ptr connection; { std::lock_guard lock(map_mutex_); if (weak_by_ptr_.count(raw) == 0) { return nullptr; } connection = weak_by_ptr_[raw].lock(); if (!connection) { weak_by_ptr_.erase(raw); } } return connection; } void WsConnectionContext::RememberConnection(void* raw, std::weak_ptr conn) { std::lock_guard lock(map_mutex_); weak_by_ptr_.emplace( std::pair>(raw, conn)); } void WsConnectionContext::ForgetConnection(void* raw) { std::lock_guard lock(map_mutex_); weak_by_ptr_.erase(raw); } WsConnection::WsConnection(int port, const std::string& addr, const std::string& path, ServerConfig::Security security, std::weak_ptr observer, std::shared_ptr context) : port_(port), addr_(addr), path_(path), security_(security), observer_(observer), context_(context) {} WsConnection::~WsConnection() { context_->ForgetConnection(this); // This will cause the callback to be called which will drop the connection // after seeing the context doesn't remember this object lws_callback_on_writable(wsi_); } void WsConnection::Connect() { memset(&extended_sul_.sul, 0, sizeof(extended_sul_.sul)); extended_sul_.weak_this = weak_from_this(); lws_sul_schedule(context_->lws_context(), 0, &extended_sul_.sul, CreateConnectionCallback, 1); } void WsConnection::OnError(const std::string& error) { auto observer = observer_.lock(); if (observer) { observer->OnError(error); } } void WsConnection::OnReceive(const uint8_t* data, size_t len, bool is_binary) { auto observer = observer_.lock(); if (observer) { observer->OnReceive(data, len, is_binary); } } void WsConnection::OnOpen() { auto observer = observer_.lock(); if (observer) { observer->OnOpen(); } } void WsConnection::OnClose() { auto observer = observer_.lock(); if (observer) { observer->OnClose(); } } void WsConnection::OnWriteable() { WsBuffer buffer; { std::lock_guard lock(write_queue_mutex_); if (write_queue_.size() == 0) { return; } buffer = std::move(write_queue_.front()); write_queue_.pop_front(); } auto flags = lws_write_ws_flags( buffer.is_binary() ? LWS_WRITE_BINARY : LWS_WRITE_TEXT, true, true); auto res = lws_write(wsi_, buffer.data(), buffer.size(), (enum lws_write_protocol)flags); if (res != buffer.size()) { LOG(WARNING) << "Unable to send the entire message!"; } } bool WsConnection::Send(const Json::Value& msg) { Json::StreamWriterBuilder factory; auto str = Json::writeString(factory, msg); return Send(reinterpret_cast(str.c_str()), str.size()); } bool WsConnection::Send(const uint8_t* data, size_t len, bool binary) { if (!wsi_) { LOG(WARNING) << "Send called on an uninitialized connection!!"; return false; } WsBuffer buffer(data, len, binary); { std::lock_guard lock(write_queue_mutex_); write_queue_.emplace_back(std::move(buffer)); } lws_callback_on_writable(wsi_); return true; } int LwsCallback(struct lws* wsi, enum lws_callback_reasons reason, void* user, void* in, size_t len) { constexpr int DROP = -1; constexpr int OK = 0; // For some values of `reason`, `user` doesn't point to the value provided // when the connection was created. This function object should be used with // care. auto with_connection = [wsi, user](std::function)> cb) { auto context = reinterpret_cast(user); auto connection = context->GetConnection(wsi); if (!connection) { return DROP; } cb(connection); return OK; }; switch (reason) { case LWS_CALLBACK_CLIENT_CONNECTION_ERROR: return with_connection([in](std::shared_ptr connection) { connection->OnError(in ? (char*)in : "(null)"); }); case LWS_CALLBACK_CLIENT_RECEIVE: return with_connection( [in, len, wsi](std::shared_ptr connection) { connection->OnReceive((const uint8_t*)in, len, lws_frame_is_binary(wsi)); }); case LWS_CALLBACK_CLIENT_ESTABLISHED: return with_connection([](std::shared_ptr connection) { connection->OnOpen(); }); case LWS_CALLBACK_CLIENT_CLOSED: return with_connection([](std::shared_ptr connection) { connection->OnClose(); }); case LWS_CALLBACK_CLIENT_WRITEABLE: return with_connection([](std::shared_ptr connection) { connection->OnWriteable(); }); default: LOG(VERBOSE) << "Unhandled value: " << reason; return lws_callback_http_dummy(wsi, reason, user, in, len); } } void CreateConnectionCallback(lws_sorted_usec_list_t* sul) { std::shared_ptr connection = reinterpret_cast(sul) ->weak_this.lock(); if (!connection) { LOG(WARNING) << "The object was already destroyed by the time of the first " << "connection attempt. That's unusual."; return; } connection->ConnectInner(); } void WsConnection::ConnectInner() { struct lws_client_connect_info connect_info; memset(&connect_info, 0, sizeof(connect_info)); connect_info.context = context_->lws_context(); connect_info.port = port_; connect_info.address = addr_.c_str(); connect_info.path = path_.c_str(); connect_info.host = connect_info.address; connect_info.origin = connect_info.address; switch (security_) { case ServerConfig::Security::kAllowSelfSigned: connect_info.ssl_connection = LCCSCF_ALLOW_SELFSIGNED | LCCSCF_SKIP_SERVER_CERT_HOSTNAME_CHECK | LCCSCF_USE_SSL; break; case ServerConfig::Security::kStrict: connect_info.ssl_connection = LCCSCF_USE_SSL; break; case ServerConfig::Security::kInsecure: connect_info.ssl_connection = 0; break; } connect_info.protocol = "webrtc-operator"; connect_info.local_protocol_name = kProtocolName; connect_info.pwsi = &wsi_; connect_info.retry_and_idle_policy = &kRetry; // There is no guarantee the connection object still exists when the callback // is called. Put the context instead as the user data which is guaranteed to // still exist and holds a weak ptr to the connection. connect_info.userdata = context_.get(); if (lws_client_connect_via_info(&connect_info)) { // wsi_ is not initialized until after the call to // lws_client_connect_via_info(). Luckily, this is guaranteed to run before // the protocol callback is called because it runs in the same loop. context_->RememberConnection(wsi_, weak_from_this()); } else { LOG(ERROR) << "Connection failed!"; } } } // namespace webrtc_streaming } // namespace cuttlefish