1 //
2 // Copyright (C) 2020 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 #include "host/frontend/webrtc_operator/client_handler.h"
17 
#include <algorithm>
#include <random>
#include <utility>
20 
21 #include <android-base/logging.h>
22 
23 #include "host/frontend/webrtc_operator/constants/signaling_constants.h"
24 #include "host/frontend/webrtc_operator/device_handler.h"
25 
26 namespace cuttlefish {
27 
namespace {

// Returns a string of `len` characters, each drawn uniformly from [0-9A-Za-z].
//
// The result is used as the secret that identifies a polling client (see
// PollConnectionStore::Add). Every character is therefore drawn directly from
// std::random_device. Seeding a pseudo-random engine with a single 32-bit
// random_device value would cap the number of distinct secrets at 2^32,
// no matter how long they are.
std::string RandomClientSecret(size_t len) {
  static constexpr char kChars[] =
      "0123456789"
      "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
      "abcdefghijklmnopqrstuvwxyz";
  // sizeof counts the terminating '\0', which must never be selected.
  static constexpr size_t kNumChars = sizeof(kChars) - 1;
  std::random_device rd;
  std::uniform_int_distribution<size_t> pick{0, kNumChars - 1};
  std::string ret(len, '\0');
  std::generate_n(ret.begin(), len, [&]() { return kChars[pick(rd)]; });
  return ret;
}

}  // namespace
42 
ClientWSHandler(struct lws * wsi,DeviceRegistry * registry,const ServerConfig & server_config)43 ClientWSHandler::ClientWSHandler(struct lws* wsi, DeviceRegistry* registry,
44                              const ServerConfig& server_config)
45     : SignalHandler(wsi, registry, server_config),
46       device_handler_(),
47       client_id_(0) {}
48 
OnClosed()49 void ClientWSHandler::OnClosed() {
50   auto device_handler = device_handler_.lock();
51   if (device_handler) {
52     device_handler->SendClientDisconnectMessage(client_id_);
53   }
54 }
55 
SendDeviceMessage(const Json::Value & device_message)56 void ClientWSHandler::SendDeviceMessage(const Json::Value& device_message) {
57   Json::Value message;
58   message[webrtc_signaling::kTypeField] = webrtc_signaling::kDeviceMessageType;
59   message[webrtc_signaling::kPayloadField] = device_message;
60   Reply(message);
61 }
62 
handleMessage(const std::string & type,const Json::Value & message)63 void ClientWSHandler::handleMessage(const std::string& type,
64                                   const Json::Value& message) {
65   if (type == webrtc_signaling::kConnectType) {
66     handleConnectionRequest(message);
67   } else if (type == webrtc_signaling::kForwardType) {
68     handleForward(message);
69   } else {
70     LogAndReplyError("Unknown message type: " + type);
71   }
72 }
73 
handleConnectionRequest(const Json::Value & message)74 void ClientWSHandler::handleConnectionRequest(const Json::Value& message) {
75   if (client_id_ > 0) {
76     LogAndReplyError(
77         "Attempt to connect to multiple devices over same websocket");
78     Close();
79     return;
80   }
81   if (!message.isMember(webrtc_signaling::kDeviceIdField) ||
82       !message[webrtc_signaling::kDeviceIdField].isString()) {
83     LogAndReplyError("Invalid connection request: Missing device id");
84     Close();
85     return;
86   }
87   auto device_id = message[webrtc_signaling::kDeviceIdField].asString();
88   // Always send the server config back, even if the requested device is not
89   // registered. Applications may put clients on hold until the device is ready
90   // to connect.
91   SendServerConfig();
92 
93   auto device_handler = registry_->GetDevice(device_id);
94   if (!device_handler) {
95     LogAndReplyError("Connection failed: Device not found: '" + device_id +
96                      "'");
97     Close();
98     return;
99   }
100 
101   client_id_ = device_handler->RegisterClient(shared_from_this());
102   device_handler_ = device_handler;
103   Json::Value device_info_reply;
104   device_info_reply[webrtc_signaling::kTypeField] =
105       webrtc_signaling::kDeviceInfoType;
106   device_info_reply[webrtc_signaling::kDeviceInfoField] =
107       device_handler->device_info();
108   Reply(device_info_reply);
109 }
110 
handleForward(const Json::Value & message)111 void ClientWSHandler::handleForward(const Json::Value& message) {
112   if (client_id_ == 0) {
113     LogAndReplyError("Forward failed: No device associated to client");
114     Close();
115     return;
116   }
117   if (!message.isMember(webrtc_signaling::kPayloadField)) {
118     LogAndReplyError("Forward failed: No payload present in message");
119     Close();
120     return;
121   }
122   auto device_handler = device_handler_.lock();
123   if (!device_handler) {
124     LogAndReplyError("Forward failed: Device disconnected");
125     // Disconnect this client since the device is gone
126     Close();
127     return;
128   }
129   device_handler->SendClientMessage(client_id_,
130                                     message[webrtc_signaling::kPayloadField]);
131 }
132 
ClientWSHandlerFactory(DeviceRegistry * registry,const ServerConfig & server_config)133 ClientWSHandlerFactory::ClientWSHandlerFactory(DeviceRegistry* registry,
134                                            const ServerConfig& server_config)
135   : registry_(registry),
136     server_config_(server_config) {}
137 
Build(struct lws * wsi)138 std::shared_ptr<WebSocketHandler> ClientWSHandlerFactory::Build(struct lws* wsi) {
139   return std::shared_ptr<WebSocketHandler>(
140       new ClientWSHandler(wsi, registry_, server_config_));
141 }
142 
143 /******************************************************************************/
144 
145 class PollConnectionHandler : public ClientHandler {
146  public:
147   PollConnectionHandler() = default;
148 
SendDeviceMessage(const Json::Value & message)149   void SendDeviceMessage(const Json::Value& message) override {
150     constexpr size_t kMaxMessagesInQueue = 1000;
151     if (messages_.size() > kMaxMessagesInQueue) {
152       LOG(ERROR) << "Polling client " << client_id_ << " reached "
153                  << kMaxMessagesInQueue
154                  << " messages queued. Started to drop messages.";
155       return;
156     }
157     messages_.push_back(message);
158   }
159 
PollMessages()160   std::vector<Json::Value> PollMessages() {
161     std::vector<Json::Value> ret;
162     std::swap(ret, messages_);
163     return ret;
164   }
165 
SetDeviceHandler(std::weak_ptr<DeviceHandler> device_handler)166   void SetDeviceHandler(std::weak_ptr<DeviceHandler> device_handler) {
167     device_handler_ = device_handler;
168   }
169 
SetClientId(size_t client_id)170   void SetClientId(size_t client_id) { client_id_ = client_id; }
171 
client_id() const172   size_t client_id() const { return client_id_; }
device_handler() const173   std::shared_ptr<DeviceHandler> device_handler() const {
174     return device_handler_.lock();
175   }
176 
177  private:
178   size_t client_id_ = 0;
179   std::weak_ptr<DeviceHandler> device_handler_;
180   std::vector<Json::Value> messages_;
181 };
182 
Get(const std::string & conn_id) const183 std::shared_ptr<PollConnectionHandler> PollConnectionStore::Get(
184     const std::string& conn_id) const {
185   if (!handlers_.count(conn_id)) {
186     return nullptr;
187   }
188   return handlers_.at(conn_id);
189 }
190 
Add(std::shared_ptr<PollConnectionHandler> handler)191 std::string PollConnectionStore::Add(std::shared_ptr<PollConnectionHandler> handler) {
192   std::string conn_id;
193   do {
194     conn_id = RandomClientSecret(64);
195   } while (handlers_.count(conn_id));
196   handlers_[conn_id] = handler;
197   return conn_id;
198 }
199 
ClientDynHandler(struct lws * wsi,PollConnectionStore * poll_store)200 ClientDynHandler::ClientDynHandler(struct lws* wsi,
201                                    PollConnectionStore* poll_store)
202     : DynHandler(wsi), poll_store_(poll_store) {}
203 
DoGet()204 HttpStatusCode ClientDynHandler::DoGet() {
205   // No message from the client uses the GET method because all of them
206   // change the server state somehow
207   return HttpStatusCode::MethodNotAllowed;
208 }
209 
Reply(const Json::Value & json)210 void ClientDynHandler::Reply(const Json::Value& json) {
211   Json::StreamWriterBuilder factory;
212   auto replyAsString = Json::writeString(factory, json);
213   AppendDataOut(replyAsString);
214 }
215 
ReplyError(const std::string & message)216 void ClientDynHandler::ReplyError(const std::string& message) {
217   LOG(ERROR) << message;
218   Json::Value reply;
219   reply["type"] = "error";
220   reply["error"] = message;
221   Reply(reply);
222 }
223 
DoPost()224 HttpStatusCode ClientDynHandler::DoPost() {
225   auto& data = GetDataIn();
226   Json::Value json_message;
227   std::shared_ptr<PollConnectionHandler> poll_handler;
228   if (data.size() > 0) {
229     Json::CharReaderBuilder builder;
230     std::unique_ptr<Json::CharReader> json_reader(builder.newCharReader());
231     std::string error_message;
232     if (!json_reader->parse(data.c_str(), data.c_str() + data.size(), &json_message,
233                             &error_message)) {
234       ReplyError("Error parsing JSON: " + error_message);
235       // Rate limiting would be a good idea here
236       return HttpStatusCode::BadRequest;
237     }
238 
239     std::string conn_id;
240     if (json_message.isMember(webrtc_signaling::kClientSecretField)) {
241       conn_id =
242           json_message[webrtc_signaling::kClientSecretField].asString();
243       poll_handler = poll_store_->Get(conn_id);
244       if (!poll_handler) {
245         ReplyError("Error: Unknown connection id" + conn_id);
246         return HttpStatusCode::Unauthorized;
247       }
248     }
249   }
250   return DoPostInner(poll_handler, json_message);
251 }
252 
Poll(std::shared_ptr<PollConnectionHandler> poll_handler)253 HttpStatusCode ClientDynHandler::Poll(
254     std::shared_ptr<PollConnectionHandler> poll_handler) {
255   if (!poll_handler) {
256     ReplyError("Poll failed: No device associated to client");
257     return HttpStatusCode::Unauthorized;
258   }
259   auto messages = poll_handler->PollMessages();
260   Json::Value reply(Json::arrayValue);
261   for (auto& msg : messages) {
262     reply.append(msg);
263   }
264   Reply(reply);
265   return HttpStatusCode::Ok;
266 }
267 
ConnectHandler(struct lws * wsi,DeviceRegistry * registry,PollConnectionStore * poll_store)268 ConnectHandler::ConnectHandler(struct lws* wsi, DeviceRegistry* registry,
269                                PollConnectionStore* poll_store)
270     : ClientDynHandler(wsi, poll_store), registry_(registry) {}
271 
DoPostInner(std::shared_ptr<PollConnectionHandler> poll_handler,const Json::Value & message)272 HttpStatusCode ConnectHandler::DoPostInner(
273     std::shared_ptr<PollConnectionHandler> poll_handler,
274     const Json::Value& message) {
275   if (!message.isMember(webrtc_signaling::kDeviceIdField) ||
276       !message[webrtc_signaling::kDeviceIdField].isString()) {
277     ReplyError("Invalid connection request: Missing device id");
278     return HttpStatusCode::BadRequest;
279   }
280   auto device_id = message[webrtc_signaling::kDeviceIdField].asString();
281 
282   auto device_handler = registry_->GetDevice(device_id);
283   if (!device_handler) {
284     ReplyError("Connection failed: Device not found: '" + device_id + "'");
285     return HttpStatusCode::NotFound;
286   }
287 
288   poll_handler = std::make_shared<PollConnectionHandler>();
289   poll_handler->SetClientId(device_handler->RegisterClient(poll_handler));
290   poll_handler->SetDeviceHandler(device_handler);
291   auto conn_id = poll_store_->Add(poll_handler);
292 
293   Json::Value device_info_reply;
294   device_info_reply[webrtc_signaling::kClientSecretField] = conn_id;
295   device_info_reply[webrtc_signaling::kTypeField] =
296       webrtc_signaling::kDeviceInfoType;
297   device_info_reply[webrtc_signaling::kDeviceInfoField] =
298       device_handler->device_info();
299   Reply(device_info_reply);
300 
301   return HttpStatusCode::Ok;
302 }
303 
ForwardHandler(struct lws * wsi,PollConnectionStore * poll_store)304 ForwardHandler::ForwardHandler(struct lws* wsi,
305                                PollConnectionStore* poll_store)
306     : ClientDynHandler(wsi, poll_store) {}
307 
DoPostInner(std::shared_ptr<PollConnectionHandler> poll_handler,const Json::Value & message)308 HttpStatusCode ForwardHandler::DoPostInner(
309     std::shared_ptr<PollConnectionHandler> poll_handler,
310     const Json::Value& message) {
311   if (!poll_handler) {
312     ReplyError("Forward failed: No device associated to client");
313     return HttpStatusCode::Unauthorized;
314   }
315   auto client_id = poll_handler->client_id();
316   if (client_id == 0) {
317     ReplyError("Forward failed: No device associated to client");
318     return HttpStatusCode::Unauthorized;
319   }
320   if (!message.isMember(webrtc_signaling::kPayloadField)) {
321     ReplyError("Forward failed: No payload present in message");
322     return HttpStatusCode::BadRequest;
323   }
324   auto device_handler = poll_handler->device_handler();
325   if (!device_handler) {
326     ReplyError("Forward failed: Device disconnected");
327     return HttpStatusCode::NotFound;
328   }
329   device_handler->SendClientMessage(client_id,
330                                     message[webrtc_signaling::kPayloadField]);
331   // Don't waste an HTTP session returning nothing, send any pending device
332   // messages to the client instead.
333   return Poll(poll_handler);
334 }
335 
PollHandler(struct lws * wsi,PollConnectionStore * poll_store)336 PollHandler::PollHandler(struct lws* wsi, PollConnectionStore* poll_store)
337     : ClientDynHandler(wsi, poll_store) {}
338 
DoPostInner(std::shared_ptr<PollConnectionHandler> poll_handler,const Json::Value &)339 HttpStatusCode PollHandler::DoPostInner(
340     std::shared_ptr<PollConnectionHandler> poll_handler,
341     const Json::Value& /*message*/) {
342   return Poll(poll_handler);
343 }
344 
ConfigHandler(struct lws * wsi,const ServerConfig & server_config)345 ConfigHandler::ConfigHandler(struct lws* wsi, const ServerConfig& server_config)
346     : DynHandler(wsi), server_config_(server_config) {}
347 
DoGet()348 HttpStatusCode ConfigHandler::DoGet() {
349   Json::Value reply = server_config_.ToJson();
350   reply[webrtc_signaling::kTypeField] = webrtc_signaling::kConfigType;
351   Json::StreamWriterBuilder factory;
352   auto replyAsString = Json::writeString(factory, reply);
353   AppendDataOut(replyAsString);
354   return HttpStatusCode::Ok;
355 }
356 
DoPost()357 HttpStatusCode ConfigHandler::DoPost() {
358   return HttpStatusCode::MethodNotAllowed;
359 }
360 
361 }  // namespace cuttlefish
362