1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "osp/public/presentation/presentation_controller.h"
6 
7 #include <algorithm>
8 #include <sstream>
9 #include <type_traits>
10 
11 #include "absl/types/optional.h"
12 #include "osp/impl/presentation/url_availability_requester.h"
13 #include "osp/msgs/osp_messages.h"
14 #include "osp/public/message_demuxer.h"
15 #include "osp/public/network_service_manager.h"
16 #include "osp/public/protocol_connection_client.h"
17 #include "osp/public/request_response_handler.h"
18 #include "util/osp_logging.h"
19 
20 namespace openscreen {
21 namespace osp {
22 
// Injects, into a request struct, the type aliases and constants describing
// the Presentation API message pair named by |base_name|:
//   RequestMsgType / ResponseMsgType  - msgs::Presentation<base>Request/Response
//   kResponseType                     - msgs::Type tag of the response
//   kEncoder / kDecoder               - CBOR encode/decode function pointers
// These are presumably what RequestResponseHandler<T> reads from T (its header
// is not visible here). The use site supplies the trailing semicolon.
#define DECLARE_MSG_REQUEST_RESPONSE(base_name)                        \
  using RequestMsgType = msgs::Presentation##base_name##Request;       \
  using ResponseMsgType = msgs::Presentation##base_name##Response;     \
  static constexpr msgs::Type kResponseType =                          \
      msgs::Type::kPresentation##base_name##Response;                  \
  static constexpr MessageEncodingFunction<RequestMsgType> kEncoder =  \
      &msgs::EncodePresentation##base_name##Request;                   \
  static constexpr MessageDecodingFunction<ResponseMsgType> kDecoder = \
      &msgs::DecodePresentation##base_name##Response
33 
34 struct StartRequest {
35   DECLARE_MSG_REQUEST_RESPONSE(Start);
36 
37   msgs::PresentationStartRequest request;
38   RequestDelegate* delegate;
39   Connection::Delegate* presentation_connection_delegate;
40 };
41 
42 struct ConnectionOpenRequest {
43   DECLARE_MSG_REQUEST_RESPONSE(ConnectionOpen);
44 
45   msgs::PresentationConnectionOpenRequest request;
46   RequestDelegate* delegate;
47   Connection::Delegate* presentation_connection_delegate;
48   std::unique_ptr<Connection> connection;
49 };
50 
51 struct ConnectionCloseRequest {
52   DECLARE_MSG_REQUEST_RESPONSE(ConnectionClose);
53 
54   msgs::PresentationConnectionCloseRequest request;
55 };
56 
57 struct TerminationRequest {
58   DECLARE_MSG_REQUEST_RESPONSE(Termination);
59 
60   msgs::PresentationTerminationRequest request;
61 };
62 
63 class Controller::MessageGroupStreams final
64     : public ProtocolConnectionClient::ConnectionRequestCallback,
65       public ProtocolConnection::Observer,
66       public RequestResponseHandler<StartRequest>::Delegate,
67       public RequestResponseHandler<ConnectionOpenRequest>::Delegate,
68       public RequestResponseHandler<ConnectionCloseRequest>::Delegate,
69       public RequestResponseHandler<TerminationRequest>::Delegate {
70  public:
71   MessageGroupStreams(Controller* controller, const std::string& service_id);
72   ~MessageGroupStreams();
73 
74   uint64_t SendStartRequest(StartRequest request);
75   void CancelStartRequest(uint64_t request_id);
76   void OnMatchedResponse(StartRequest* request,
77                          msgs::PresentationStartResponse* response,
78                          uint64_t endpoint_id) override;
79   void OnError(StartRequest* request, Error error) override;
80 
81   uint64_t SendConnectionOpenRequest(ConnectionOpenRequest request);
82   void CancelConnectionOpenRequest(uint64_t request_id);
83   void OnMatchedResponse(ConnectionOpenRequest* request,
84                          msgs::PresentationConnectionOpenResponse* response,
85                          uint64_t endpoint_id) override;
86   void OnError(ConnectionOpenRequest* request, Error error) override;
87 
88   void SendConnectionCloseRequest(ConnectionCloseRequest request);
89   void OnMatchedResponse(ConnectionCloseRequest* request,
90                          msgs::PresentationConnectionCloseResponse* response,
91                          uint64_t endpoint_id) override;
92   void OnError(ConnectionCloseRequest* request, Error error) override;
93 
94   void SendTerminationRequest(TerminationRequest request);
95   void OnMatchedResponse(TerminationRequest* request,
96                          msgs::PresentationTerminationResponse* response,
97                          uint64_t endpoint_id) override;
98   void OnError(TerminationRequest* request, Error error) override;
99 
100   // ProtocolConnectionClient::ConnectionRequestCallback overrides.
101   void OnConnectionOpened(
102       uint64_t request_id,
103       std::unique_ptr<ProtocolConnection> connection) override;
104   void OnConnectionFailed(uint64_t request_id) override;
105 
106   // ProtocolConnection::Observer overrides.
107   void OnConnectionClosed(const ProtocolConnection& connection) override;
108 
109  private:
110   uint64_t GetNextInternalRequestId();
111 
112   Controller* const controller_;
113   const std::string service_id_;
114 
115   uint64_t next_internal_request_id_ = 1;
116   ProtocolConnectionClient::ConnectRequest initiation_connect_request_;
117   std::unique_ptr<ProtocolConnection> initiation_protocol_connection_;
118   ProtocolConnectionClient::ConnectRequest connection_connect_request_;
119   std::unique_ptr<ProtocolConnection> connection_protocol_connection_;
120 
121   // TODO(btolsch): Improve the ergo of QuicClient::Connect because this is bad.
122   bool initiation_connect_request_stack_{false};
123   bool connection_connect_request_stack_{false};
124 
125   RequestResponseHandler<StartRequest> initiation_handler_;
126   RequestResponseHandler<ConnectionOpenRequest> connection_open_handler_;
127   RequestResponseHandler<ConnectionCloseRequest> connection_close_handler_;
128   RequestResponseHandler<TerminationRequest> termination_handler_;
129 };
130 
MessageGroupStreams(Controller * controller,const std::string & service_id)131 Controller::MessageGroupStreams::MessageGroupStreams(
132     Controller* controller,
133     const std::string& service_id)
134     : controller_(controller),
135       service_id_(service_id),
136       initiation_handler_(this),
137       connection_open_handler_(this),
138       connection_close_handler_(this),
139       termination_handler_(this) {}
140 
141 Controller::MessageGroupStreams::~MessageGroupStreams() = default;
142 
SendStartRequest(StartRequest request)143 uint64_t Controller::MessageGroupStreams::SendStartRequest(
144     StartRequest request) {
145   uint64_t request_id = GetNextInternalRequestId();
146   if (!initiation_protocol_connection_ && !initiation_connect_request_) {
147     initiation_connect_request_stack_ = true;
148     initiation_connect_request_ =
149         NetworkServiceManager::Get()->GetProtocolConnectionClient()->Connect(
150             controller_->receiver_endpoints_[service_id_], this);
151     initiation_connect_request_stack_ = false;
152   }
153   initiation_handler_.WriteMessage(request_id, std::move(request));
154   return request_id;
155 }
156 
CancelStartRequest(uint64_t request_id)157 void Controller::MessageGroupStreams::CancelStartRequest(uint64_t request_id) {
158   // TODO(btolsch): Instead, mark the |request_id| for immediate termination if
159   // we get a successful response.
160   initiation_handler_.CancelMessage(request_id);
161 }
162 
OnMatchedResponse(StartRequest * request,msgs::PresentationStartResponse * response,uint64_t endpoint_id)163 void Controller::MessageGroupStreams::OnMatchedResponse(
164     StartRequest* request,
165     msgs::PresentationStartResponse* response,
166     uint64_t endpoint_id) {
167   if (response->result != msgs::PresentationStartResponse_result::kSuccess) {
168     std::stringstream ss;
169     ss << "presentation-start-response for " << request->request.url
170        << " failed: " << static_cast<int>(response->result);
171     Error error(Error::Code::kUnknownStartError, ss.str());
172     OSP_LOG_INFO << error.message();
173     request->delegate->OnError(std::move(error));
174     return;
175   }
176   OSP_LOG_INFO << "presentation started for " << request->request.url;
177   Controller::ControlledPresentation& presentation =
178       controller_->presentations_[request->request.presentation_id];
179   presentation.service_id = service_id_;
180   presentation.url = request->request.url;
181   auto connection = std::make_unique<Connection>(
182       Connection::PresentationInfo{request->request.presentation_id,
183                                    request->request.url},
184       request->presentation_connection_delegate, controller_);
185   controller_->OpenConnection(response->connection_id, endpoint_id, service_id_,
186                               request->delegate, std::move(connection),
187                               NetworkServiceManager::Get()
188                                   ->GetProtocolConnectionClient()
189                                   ->CreateProtocolConnection(endpoint_id));
190 }
191 
OnError(StartRequest * request,Error error)192 void Controller::MessageGroupStreams::OnError(StartRequest* request,
193                                               Error error) {
194   request->delegate->OnError(std::move(error));
195 }
196 
SendConnectionOpenRequest(ConnectionOpenRequest request)197 uint64_t Controller::MessageGroupStreams::SendConnectionOpenRequest(
198     ConnectionOpenRequest request) {
199   uint64_t request_id = GetNextInternalRequestId();
200   if (!connection_protocol_connection_ && !connection_connect_request_) {
201     connection_connect_request_stack_ = true;
202     connection_connect_request_ =
203         NetworkServiceManager::Get()->GetProtocolConnectionClient()->Connect(
204             controller_->receiver_endpoints_[service_id_], this);
205     connection_connect_request_stack_ = false;
206   }
207   connection_open_handler_.WriteMessage(request_id, std::move(request));
208   return request_id;
209 }
210 
CancelConnectionOpenRequest(uint64_t request_id)211 void Controller::MessageGroupStreams::CancelConnectionOpenRequest(
212     uint64_t request_id) {
213   connection_open_handler_.CancelMessage(request_id);
214 }
215 
OnMatchedResponse(ConnectionOpenRequest * request,msgs::PresentationConnectionOpenResponse * response,uint64_t endpoint_id)216 void Controller::MessageGroupStreams::OnMatchedResponse(
217     ConnectionOpenRequest* request,
218     msgs::PresentationConnectionOpenResponse* response,
219     uint64_t endpoint_id) {
220   if (response->result !=
221       msgs::PresentationConnectionOpenResponse_result::kSuccess) {
222     std::stringstream ss;
223     ss << "presentation-connection-open-response for " << request->request.url
224        << " failed: " << static_cast<int>(response->result);
225     Error error(Error::Code::kUnknownStartError, ss.str());
226     OSP_LOG_INFO << error.message();
227     request->delegate->OnError(std::move(error));
228     return;
229   }
230   OSP_LOG_INFO << "presentation connection opened to "
231                << request->request.presentation_id;
232   if (request->presentation_connection_delegate) {
233     request->connection = std::make_unique<Connection>(
234         Connection::PresentationInfo{request->request.presentation_id,
235                                      request->request.url},
236         request->presentation_connection_delegate, controller_);
237   }
238   std::unique_ptr<ProtocolConnection> protocol_connection =
239       NetworkServiceManager::Get()
240           ->GetProtocolConnectionClient()
241           ->CreateProtocolConnection(endpoint_id);
242   request->connection->OnConnected(response->connection_id, endpoint_id,
243                                    std::move(protocol_connection));
244   controller_->AddConnection(request->connection.get());
245   request->delegate->OnConnection(std::move(request->connection));
246 }
247 
OnError(ConnectionOpenRequest * request,Error error)248 void Controller::MessageGroupStreams::OnError(ConnectionOpenRequest* request,
249                                               Error error) {
250   request->delegate->OnError(std::move(error));
251 }
252 
SendConnectionCloseRequest(ConnectionCloseRequest request)253 void Controller::MessageGroupStreams::SendConnectionCloseRequest(
254     ConnectionCloseRequest request) {
255   if (!connection_protocol_connection_ && !connection_connect_request_) {
256     connection_connect_request_stack_ = true;
257     connection_connect_request_ =
258         NetworkServiceManager::Get()->GetProtocolConnectionClient()->Connect(
259             controller_->receiver_endpoints_[service_id_], this);
260     connection_connect_request_stack_ = false;
261   }
262   connection_close_handler_.WriteMessage(std::move(request));
263 }
264 
OnMatchedResponse(ConnectionCloseRequest * request,msgs::PresentationConnectionCloseResponse * response,uint64_t endpoint_id)265 void Controller::MessageGroupStreams::OnMatchedResponse(
266     ConnectionCloseRequest* request,
267     msgs::PresentationConnectionCloseResponse* response,
268     uint64_t endpoint_id) {
269   OSP_LOG_IF(INFO,
270              response->result !=
271                  msgs::PresentationConnectionCloseResponse_result::kSuccess)
272       << "error in presentation-connection-close-response: "
273       << static_cast<int>(response->result);
274 }
275 
OnError(ConnectionCloseRequest * request,Error error)276 void Controller::MessageGroupStreams::OnError(ConnectionCloseRequest* request,
277                                               Error error) {
278   OSP_LOG_INFO << "got error when closing connection "
279                << request->request.connection_id << ": " << error;
280 }
281 
SendTerminationRequest(TerminationRequest request)282 void Controller::MessageGroupStreams::SendTerminationRequest(
283     TerminationRequest request) {
284   if (!initiation_protocol_connection_ && !initiation_connect_request_) {
285     initiation_connect_request_ =
286         NetworkServiceManager::Get()->GetProtocolConnectionClient()->Connect(
287             controller_->receiver_endpoints_[service_id_], this);
288   }
289   termination_handler_.WriteMessage(std::move(request));
290 }
291 
OnMatchedResponse(TerminationRequest * request,msgs::PresentationTerminationResponse * response,uint64_t endpoint_id)292 void Controller::MessageGroupStreams::OnMatchedResponse(
293     TerminationRequest* request,
294     msgs::PresentationTerminationResponse* response,
295     uint64_t endpoint_id) {
296   OSP_VLOG << "got presentation-termination-response for "
297            << request->request.presentation_id << " with result "
298            << static_cast<int>(response->result);
299   controller_->TerminatePresentationById(request->request.presentation_id);
300 }
301 
OnError(TerminationRequest * request,Error error)302 void Controller::MessageGroupStreams::OnError(TerminationRequest* request,
303                                               Error error) {}
304 
OnConnectionOpened(uint64_t request_id,std::unique_ptr<ProtocolConnection> connection)305 void Controller::MessageGroupStreams::OnConnectionOpened(
306     uint64_t request_id,
307     std::unique_ptr<ProtocolConnection> connection) {
308   if ((initiation_connect_request_ &&
309        initiation_connect_request_.request_id() == request_id) ||
310       initiation_connect_request_stack_) {
311     initiation_protocol_connection_ = std::move(connection);
312     initiation_protocol_connection_->SetObserver(this);
313     initiation_connect_request_.MarkComplete();
314     initiation_handler_.SetConnection(initiation_protocol_connection_.get());
315     termination_handler_.SetConnection(initiation_protocol_connection_.get());
316   } else if ((connection_connect_request_ &&
317               connection_connect_request_.request_id() == request_id) ||
318              connection_connect_request_stack_) {
319     connection_protocol_connection_ = std::move(connection);
320     connection_protocol_connection_->SetObserver(this);
321     connection_connect_request_.MarkComplete();
322     connection_open_handler_.SetConnection(
323         connection_protocol_connection_.get());
324     connection_close_handler_.SetConnection(
325         connection_protocol_connection_.get());
326   }
327 }
328 
OnConnectionFailed(uint64_t request_id)329 void Controller::MessageGroupStreams::OnConnectionFailed(uint64_t request_id) {
330   if (initiation_connect_request_ &&
331       initiation_connect_request_.request_id() == request_id) {
332     initiation_connect_request_.MarkComplete();
333     initiation_handler_.Reset();
334     termination_handler_.Reset();
335   } else if (connection_connect_request_ &&
336              connection_connect_request_.request_id() == request_id) {
337     connection_connect_request_.MarkComplete();
338     connection_open_handler_.Reset();
339     connection_close_handler_.Reset();
340   }
341 }
342 
OnConnectionClosed(const ProtocolConnection & connection)343 void Controller::MessageGroupStreams::OnConnectionClosed(
344     const ProtocolConnection& connection) {
345   if (&connection == initiation_protocol_connection_.get()) {
346     initiation_handler_.Reset();
347     termination_handler_.Reset();
348   }
349 }
350 
GetNextInternalRequestId()351 uint64_t Controller::MessageGroupStreams::GetNextInternalRequestId() {
352   return ++next_internal_request_id_;
353 }
354 
355 Controller::ReceiverWatch::ReceiverWatch() = default;
ReceiverWatch(Controller * controller,const std::vector<std::string> & urls,ReceiverObserver * observer)356 Controller::ReceiverWatch::ReceiverWatch(Controller* controller,
357                                          const std::vector<std::string>& urls,
358                                          ReceiverObserver* observer)
359     : urls_(urls), observer_(observer), controller_(controller) {}
360 
ReceiverWatch(Controller::ReceiverWatch && other)361 Controller::ReceiverWatch::ReceiverWatch(
362     Controller::ReceiverWatch&& other) noexcept {
363   swap(*this, other);
364 }
365 
~ReceiverWatch()366 Controller::ReceiverWatch::~ReceiverWatch() {
367   if (observer_) {
368     controller_->CancelReceiverWatch(urls_, observer_);
369   }
370   observer_ = nullptr;
371 }
372 
operator =(Controller::ReceiverWatch other)373 Controller::ReceiverWatch& Controller::ReceiverWatch::operator=(
374     Controller::ReceiverWatch other) {
375   swap(*this, other);
376   return *this;
377 }
378 
swap(Controller::ReceiverWatch & a,Controller::ReceiverWatch & b)379 void swap(Controller::ReceiverWatch& a, Controller::ReceiverWatch& b) {
380   using std::swap;
381   swap(a.urls_, b.urls_);
382   swap(a.observer_, b.observer_);
383   swap(a.controller_, b.controller_);
384 }
385 
386 Controller::ConnectRequest::ConnectRequest() = default;
ConnectRequest(Controller * controller,const std::string & service_id,bool is_reconnect,absl::optional<uint64_t> request_id)387 Controller::ConnectRequest::ConnectRequest(Controller* controller,
388                                            const std::string& service_id,
389                                            bool is_reconnect,
390                                            absl::optional<uint64_t> request_id)
391     : service_id_(service_id),
392       is_reconnect_(is_reconnect),
393       request_id_(request_id),
394       controller_(controller) {}
395 
ConnectRequest(ConnectRequest && other)396 Controller::ConnectRequest::ConnectRequest(ConnectRequest&& other) noexcept {
397   swap(*this, other);
398 }
399 
~ConnectRequest()400 Controller::ConnectRequest::~ConnectRequest() {
401   if (request_id_) {
402     controller_->CancelConnectRequest(service_id_, is_reconnect_,
403                                       request_id_.value());
404   }
405   request_id_ = 0;
406 }
407 
operator =(ConnectRequest other)408 Controller::ConnectRequest& Controller::ConnectRequest::operator=(
409     ConnectRequest other) {
410   swap(*this, other);
411   return *this;
412 }
413 
swap(Controller::ConnectRequest & a,Controller::ConnectRequest & b)414 void swap(Controller::ConnectRequest& a, Controller::ConnectRequest& b) {
415   using std::swap;
416   swap(a.service_id_, b.service_id_);
417   swap(a.is_reconnect_, b.is_reconnect_);
418   swap(a.request_id_, b.request_id_);
419   swap(a.controller_, b.controller_);
420 }
421 
Controller(ClockNowFunctionPtr now_function)422 Controller::Controller(ClockNowFunctionPtr now_function) {
423   availability_requester_ =
424       std::make_unique<UrlAvailabilityRequester>(now_function);
425   connection_manager_ =
426       std::make_unique<ConnectionManager>(NetworkServiceManager::Get()
427                                               ->GetProtocolConnectionClient()
428                                               ->message_demuxer());
429   const std::vector<ServiceInfo>& receivers =
430       NetworkServiceManager::Get()->GetMdnsServiceListener()->GetReceivers();
431   for (const auto& info : receivers) {
432     // TODO(crbug.com/openscreen/33): Replace service_id with endpoint_id when
433     // endpoint_id is more than just an IPEndpoint counter and actually relates
434     // to a device's identity.
435     receiver_endpoints_.emplace(info.service_id, info.v4_endpoint.port
436                                                      ? info.v4_endpoint
437                                                      : info.v6_endpoint);
438     availability_requester_->AddReceiver(info);
439   }
440   // TODO(btolsch): This is for |receiver_endpoints_|, but this should really be
441   // tracked elsewhere so it's available to other protocols as well.
442   NetworkServiceManager::Get()->GetMdnsServiceListener()->AddObserver(this);
443 }
444 
~Controller()445 Controller::~Controller() {
446   connection_manager_.reset();
447   NetworkServiceManager::Get()->GetMdnsServiceListener()->RemoveObserver(this);
448 }
449 
RegisterReceiverWatch(const std::vector<std::string> & urls,ReceiverObserver * observer)450 Controller::ReceiverWatch Controller::RegisterReceiverWatch(
451     const std::vector<std::string>& urls,
452     ReceiverObserver* observer) {
453   availability_requester_->AddObserver(urls, observer);
454   return ReceiverWatch(this, urls, observer);
455 }
456 
StartPresentation(const std::string & url,const std::string & service_id,RequestDelegate * delegate,Connection::Delegate * conn_delegate)457 Controller::ConnectRequest Controller::StartPresentation(
458     const std::string& url,
459     const std::string& service_id,
460     RequestDelegate* delegate,
461     Connection::Delegate* conn_delegate) {
462   StartRequest request;
463   request.request.url = url;
464   request.request.presentation_id = MakePresentationId(url, service_id);
465   request.delegate = delegate;
466   request.presentation_connection_delegate = conn_delegate;
467   uint64_t request_id =
468       group_streams_[service_id]->SendStartRequest(std::move(request));
469   constexpr bool is_reconnect = false;
470   return ConnectRequest(this, service_id, is_reconnect, request_id);
471 }
472 
ReconnectPresentation(const std::vector<std::string> & urls,const std::string & presentation_id,const std::string & service_id,RequestDelegate * delegate,Connection::Delegate * conn_delegate)473 Controller::ConnectRequest Controller::ReconnectPresentation(
474     const std::vector<std::string>& urls,
475     const std::string& presentation_id,
476     const std::string& service_id,
477     RequestDelegate* delegate,
478     Connection::Delegate* conn_delegate) {
479   auto presentation_entry = presentations_.find(presentation_id);
480   if (presentation_entry == presentations_.end()) {
481     delegate->OnError(Error::Code::kNoPresentationFound);
482     return ConnectRequest();
483   }
484   auto matching_url_it =
485       std::find(urls.begin(), urls.end(), presentation_entry->second.url);
486   if (matching_url_it == urls.end()) {
487     delegate->OnError(Error::Code::kNoPresentationFound);
488     return ConnectRequest();
489   }
490   ConnectionOpenRequest request;
491   request.request.url = presentation_entry->second.url;
492   request.request.presentation_id = presentation_id;
493   request.delegate = delegate;
494   request.presentation_connection_delegate = conn_delegate;
495   request.connection = nullptr;
496   uint64_t request_id =
497       group_streams_[service_id]->SendConnectionOpenRequest(std::move(request));
498   constexpr bool is_reconnect = true;
499   return ConnectRequest(this, service_id, is_reconnect, request_id);
500 }
501 
ReconnectConnection(std::unique_ptr<Connection> connection,RequestDelegate * delegate)502 Controller::ConnectRequest Controller::ReconnectConnection(
503     std::unique_ptr<Connection> connection,
504     RequestDelegate* delegate) {
505   if (connection->state() != Connection::State::kClosed) {
506     delegate->OnError(Error::Code::kInvalidConnectionState);
507     return ConnectRequest();
508   }
509   const Connection::PresentationInfo& info = connection->presentation_info();
510   auto presentation_entry = presentations_.find(info.id);
511   if (presentation_entry == presentations_.end() ||
512       presentation_entry->second.url != info.url) {
513     OSP_LOG_ERROR << "missing ControlledPresentation for non-terminated "
514                      "connection with info ("
515                   << info.id << ", " << info.url << ")";
516     delegate->OnError(Error::Code::kNoPresentationFound);
517     return ConnectRequest();
518   }
519   OSP_DCHECK(connection_manager_->GetConnection(connection->connection_id()))
520       << "otherwise valid connection for reconnect is unknown to the "
521          "connection manager";
522   connection_manager_->RemoveConnection(connection.get());
523   connection->OnConnecting();
524   ConnectionOpenRequest request;
525   request.request.url = info.url;
526   request.request.presentation_id = info.id;
527   request.delegate = delegate;
528   request.presentation_connection_delegate = nullptr;
529   request.connection = std::move(connection);
530   const std::string& service_id = presentation_entry->second.service_id;
531   uint64_t request_id =
532       group_streams_[service_id]->SendConnectionOpenRequest(std::move(request));
533   constexpr bool is_reconnect = true;
534   return ConnectRequest(this, service_id, is_reconnect, request_id);
535 }
536 
CloseConnection(Connection * connection,Connection::CloseReason reason)537 Error Controller::CloseConnection(Connection* connection,
538                                   Connection::CloseReason reason) {
539   auto presentation_entry =
540       presentations_.find(connection->presentation_info().id);
541   if (presentation_entry == presentations_.end()) {
542     std::stringstream ss;
543     ss << "no presentation found when trying to close connection "
544        << connection->presentation_info().id << ":"
545        << connection->connection_id();
546     return Error(Error::Code::kNoPresentationFound, ss.str());
547   }
548   ConnectionCloseRequest request;
549   request.request.connection_id = connection->connection_id();
550   group_streams_[presentation_entry->second.service_id]
551       ->SendConnectionCloseRequest(std::move(request));
552   return Error::None();
553 }
554 
OnPresentationTerminated(const std::string & presentation_id,TerminationReason reason)555 Error Controller::OnPresentationTerminated(const std::string& presentation_id,
556                                            TerminationReason reason) {
557   auto presentation_entry = presentations_.find(presentation_id);
558   if (presentation_entry == presentations_.end()) {
559     return Error::Code::kNoPresentationFound;
560   }
561   ControlledPresentation& presentation = presentation_entry->second;
562   for (auto* connection : presentation.connections) {
563     connection->OnTerminated();
564   }
565   TerminationRequest request;
566   request.request.presentation_id = presentation_id;
567   request.request.reason =
568       msgs::PresentationTerminationRequest_reason::kUserTerminatedViaController;
569   group_streams_[presentation.service_id]->SendTerminationRequest(
570       std::move(request));
571   presentations_.erase(presentation_entry);
572   termination_listener_by_id_.erase(presentation_id);
573   return Error::None();
574 }
575 
OnConnectionDestroyed(Connection * connection)576 void Controller::OnConnectionDestroyed(Connection* connection) {
577   auto presentation_entry =
578       presentations_.find(connection->presentation_info().id);
579   if (presentation_entry == presentations_.end()) {
580     return;
581   }
582 
583   std::vector<Connection*>& connections =
584       presentation_entry->second.connections;
585 
586   connections.erase(
587       std::remove(connections.begin(), connections.end(), connection),
588       connections.end());
589 
590   connection_manager_->RemoveConnection(connection);
591 }
592 
GetServiceIdForPresentationId(const std::string & presentation_id) const593 std::string Controller::GetServiceIdForPresentationId(
594     const std::string& presentation_id) const {
595   auto presentation_entry = presentations_.find(presentation_id);
596   if (presentation_entry == presentations_.end()) {
597     return "";
598   }
599   return presentation_entry->second.service_id;
600 }
601 
GetConnectionRequestGroupStream(const std::string & service_id)602 ProtocolConnection* Controller::GetConnectionRequestGroupStream(
603     const std::string& service_id) {
604   OSP_UNIMPLEMENTED();
605   return nullptr;
606 }
607 
OnError(ServiceListenerError)608 void Controller::OnError(ServiceListenerError) {}
OnMetrics(ServiceListener::Metrics)609 void Controller::OnMetrics(ServiceListener::Metrics) {}
610 
611 class Controller::TerminationListener final
612     : public MessageDemuxer::MessageCallback {
613  public:
614   TerminationListener(Controller* controller,
615                       const std::string& presentation_id,
616                       uint64_t endpoint_id);
617   ~TerminationListener() override;
618 
619   // MessageDemuxer::MessageCallback overrides.
620   ErrorOr<size_t> OnStreamMessage(uint64_t endpoint_id,
621                                   uint64_t connection_id,
622                                   msgs::Type message_type,
623                                   const uint8_t* buffer,
624                                   size_t buffer_size,
625                                   Clock::time_point now) override;
626 
627  private:
628   Controller* const controller_;
629   std::string presentation_id_;
630   MessageDemuxer::MessageWatch event_watch_;
631 };
632 
TerminationListener(Controller * controller,const std::string & presentation_id,uint64_t endpoint_id)633 Controller::TerminationListener::TerminationListener(
634     Controller* controller,
635     const std::string& presentation_id,
636     uint64_t endpoint_id)
637     : controller_(controller), presentation_id_(presentation_id) {
638   event_watch_ =
639       NetworkServiceManager::Get()
640           ->GetProtocolConnectionClient()
641           ->message_demuxer()
642           ->WatchMessageType(endpoint_id,
643                              msgs::Type::kPresentationTerminationEvent, this);
644 }
645 
646 Controller::TerminationListener::~TerminationListener() = default;
647 
OnStreamMessage(uint64_t endpoint_id,uint64_t connection_id,msgs::Type message_type,const uint8_t * buffer,size_t buffer_size,Clock::time_point now)648 ErrorOr<size_t> Controller::TerminationListener::OnStreamMessage(
649     uint64_t endpoint_id,
650     uint64_t connection_id,
651     msgs::Type message_type,
652     const uint8_t* buffer,
653     size_t buffer_size,
654     Clock::time_point now) {
655   OSP_CHECK_EQ(static_cast<int>(msgs::Type::kPresentationTerminationEvent),
656                static_cast<int>(message_type));
657   msgs::PresentationTerminationEvent event;
658   ssize_t result =
659       msgs::DecodePresentationTerminationEvent(buffer, buffer_size, &event);
660   if (result < 0) {
661     OSP_LOG_WARN << "decode presentation-termination-event error: " << result;
662     return Error::Code::kCborParsing;
663   } else if (event.presentation_id != presentation_id_) {
664     OSP_LOG_WARN << "got presentation-termination-event for wrong id: "
665                  << presentation_id_ << " vs. " << event.presentation_id;
666     return result;
667   }
668   OSP_LOG_INFO << "termination event";
669   auto presentation_entry =
670       controller_->presentations_.find(event.presentation_id);
671   if (presentation_entry != controller_->presentations_.end()) {
672     for (auto* connection : presentation_entry->second.connections)
673       connection->OnTerminated();
674     controller_->presentations_.erase(presentation_entry);
675   }
676   controller_->termination_listener_by_id_.erase(event.presentation_id);
677   return result;
678 }
679 
680 // static
MakePresentationId(const std::string & url,const std::string & service_id)681 std::string Controller::MakePresentationId(const std::string& url,
682                                            const std::string& service_id) {
683   // TODO(btolsch): This is just a placeholder for the demo. It should
684   // eventually become a GUID/unguessable token routine.
685   std::string safe_id = service_id;
686   for (auto& c : safe_id)
687     if (c < ' ' || c > '~')
688       c = '.';
689   return safe_id + ":" + url;
690 }
691 
AddConnection(Connection * connection)692 void Controller::AddConnection(Connection* connection) {
693   connection_manager_->AddConnection(connection);
694 }
695 
OpenConnection(uint64_t connection_id,uint64_t endpoint_id,const std::string & service_id,RequestDelegate * request_delegate,std::unique_ptr<Connection> && connection,std::unique_ptr<ProtocolConnection> && protocol_connection)696 void Controller::OpenConnection(
697     uint64_t connection_id,
698     uint64_t endpoint_id,
699     const std::string& service_id,
700     RequestDelegate* request_delegate,
701     std::unique_ptr<Connection>&& connection,
702     std::unique_ptr<ProtocolConnection>&& protocol_connection) {
703   connection->OnConnected(connection_id, endpoint_id,
704                           std::move(protocol_connection));
705   const std::string& presentation_id = connection->presentation_info().id;
706   auto presentation_entry = presentations_.find(presentation_id);
707   if (presentation_entry == presentations_.end()) {
708     auto emplace_entry = presentations_.emplace(
709         presentation_id,
710         ControlledPresentation{
711             service_id, connection->presentation_info().url, {}});
712     presentation_entry = emplace_entry.first;
713   }
714   ControlledPresentation& presentation = presentation_entry->second;
715   presentation.connections.push_back(connection.get());
716   AddConnection(connection.get());
717 
718   auto terminate_entry = termination_listener_by_id_.find(presentation_id);
719   if (terminate_entry == termination_listener_by_id_.end()) {
720     termination_listener_by_id_.emplace(
721         presentation_id, std::make_unique<TerminationListener>(
722                              this, presentation_id, endpoint_id));
723   }
724   request_delegate->OnConnection(std::move(connection));
725 }
726 
TerminatePresentationById(const std::string & presentation_id)727 void Controller::TerminatePresentationById(const std::string& presentation_id) {
728   auto presentation_entry = presentations_.find(presentation_id);
729   if (presentation_entry != presentations_.end()) {
730     for (auto* connection : presentation_entry->second.connections) {
731       connection->OnTerminated();
732     }
733     presentations_.erase(presentation_entry);
734   }
735 }
736 
CancelReceiverWatch(const std::vector<std::string> & urls,ReceiverObserver * observer)737 void Controller::CancelReceiverWatch(const std::vector<std::string>& urls,
738                                      ReceiverObserver* observer) {
739   availability_requester_->RemoveObserverUrls(urls, observer);
740 }
741 
CancelConnectRequest(const std::string & service_id,bool is_reconnect,uint64_t request_id)742 void Controller::CancelConnectRequest(const std::string& service_id,
743                                       bool is_reconnect,
744                                       uint64_t request_id) {
745   auto group_streams_entry = group_streams_.find(service_id);
746   if (group_streams_entry == group_streams_.end())
747     return;
748   if (is_reconnect) {
749     group_streams_entry->second->CancelConnectionOpenRequest(request_id);
750   } else {
751     group_streams_entry->second->CancelStartRequest(request_id);
752   }
753 }
754 
OnStarted()755 void Controller::OnStarted() {}
OnStopped()756 void Controller::OnStopped() {}
OnSuspended()757 void Controller::OnSuspended() {}
OnSearching()758 void Controller::OnSearching() {}
759 
OnReceiverAdded(const ServiceInfo & info)760 void Controller::OnReceiverAdded(const ServiceInfo& info) {
761   receiver_endpoints_.emplace(info.service_id, info.v4_endpoint.port
762                                                    ? info.v4_endpoint
763                                                    : info.v6_endpoint);
764   auto group_streams =
765       std::make_unique<MessageGroupStreams>(this, info.service_id);
766   group_streams_[info.service_id] = std::move(group_streams);
767   availability_requester_->AddReceiver(info);
768 }
769 
OnReceiverChanged(const ServiceInfo & info)770 void Controller::OnReceiverChanged(const ServiceInfo& info) {
771   receiver_endpoints_[info.service_id] =
772       info.v4_endpoint.port ? info.v4_endpoint : info.v6_endpoint;
773   availability_requester_->ChangeReceiver(info);
774 }
775 
OnReceiverRemoved(const ServiceInfo & info)776 void Controller::OnReceiverRemoved(const ServiceInfo& info) {
777   receiver_endpoints_.erase(info.service_id);
778   group_streams_.erase(info.service_id);
779   availability_requester_->RemoveReceiver(info);
780 }
781 
OnAllReceiversRemoved()782 void Controller::OnAllReceiversRemoved() {
783   receiver_endpoints_.clear();
784   availability_requester_->RemoveAllReceivers();
785 }
786 
787 }  // namespace osp
788 }  // namespace openscreen
789