1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "osp/public/presentation/presentation_controller.h"
6
7 #include <algorithm>
8 #include <sstream>
9 #include <type_traits>
10
11 #include "absl/types/optional.h"
12 #include "osp/impl/presentation/url_availability_requester.h"
13 #include "osp/msgs/osp_messages.h"
14 #include "osp/public/message_demuxer.h"
15 #include "osp/public/network_service_manager.h"
16 #include "osp/public/protocol_connection_client.h"
17 #include "osp/public/request_response_handler.h"
18 #include "util/osp_logging.h"
19
20 namespace openscreen {
21 namespace osp {
22
// Declares, inside a request wrapper struct, the message plumbing for the
// msgs::Presentation<base_name>Request / Presentation<base_name>Response pair:
// the response type tag, aliases for both message types, the CBOR encoder for
// the request and the decoder for the response. The wrapper structs below are
// used as RequestResponseHandler<T> template arguments, which presumably reads
// these members (TODO confirm against request_response_handler.h).
// The expansion deliberately omits the trailing semicolon; callers supply it.
#define DECLARE_MSG_REQUEST_RESPONSE(base_name)                        \
  static constexpr msgs::Type kResponseType =                          \
      msgs::Type::kPresentation##base_name##Response;                  \
  using RequestMsgType = msgs::Presentation##base_name##Request;       \
  using ResponseMsgType = msgs::Presentation##base_name##Response;     \
  static constexpr MessageEncodingFunction<RequestMsgType> kEncoder =  \
      &msgs::EncodePresentation##base_name##Request;                   \
  static constexpr MessageDecodingFunction<ResponseMsgType> kDecoder = \
      &msgs::DecodePresentation##base_name##Response
33
34 struct StartRequest {
35 DECLARE_MSG_REQUEST_RESPONSE(Start);
36
37 msgs::PresentationStartRequest request;
38 RequestDelegate* delegate;
39 Connection::Delegate* presentation_connection_delegate;
40 };
41
42 struct ConnectionOpenRequest {
43 DECLARE_MSG_REQUEST_RESPONSE(ConnectionOpen);
44
45 msgs::PresentationConnectionOpenRequest request;
46 RequestDelegate* delegate;
47 Connection::Delegate* presentation_connection_delegate;
48 std::unique_ptr<Connection> connection;
49 };
50
51 struct ConnectionCloseRequest {
52 DECLARE_MSG_REQUEST_RESPONSE(ConnectionClose);
53
54 msgs::PresentationConnectionCloseRequest request;
55 };
56
57 struct TerminationRequest {
58 DECLARE_MSG_REQUEST_RESPONSE(Termination);
59
60 msgs::PresentationTerminationRequest request;
61 };
62
63 class Controller::MessageGroupStreams final
64 : public ProtocolConnectionClient::ConnectionRequestCallback,
65 public ProtocolConnection::Observer,
66 public RequestResponseHandler<StartRequest>::Delegate,
67 public RequestResponseHandler<ConnectionOpenRequest>::Delegate,
68 public RequestResponseHandler<ConnectionCloseRequest>::Delegate,
69 public RequestResponseHandler<TerminationRequest>::Delegate {
70 public:
71 MessageGroupStreams(Controller* controller, const std::string& service_id);
72 ~MessageGroupStreams();
73
74 uint64_t SendStartRequest(StartRequest request);
75 void CancelStartRequest(uint64_t request_id);
76 void OnMatchedResponse(StartRequest* request,
77 msgs::PresentationStartResponse* response,
78 uint64_t endpoint_id) override;
79 void OnError(StartRequest* request, Error error) override;
80
81 uint64_t SendConnectionOpenRequest(ConnectionOpenRequest request);
82 void CancelConnectionOpenRequest(uint64_t request_id);
83 void OnMatchedResponse(ConnectionOpenRequest* request,
84 msgs::PresentationConnectionOpenResponse* response,
85 uint64_t endpoint_id) override;
86 void OnError(ConnectionOpenRequest* request, Error error) override;
87
88 void SendConnectionCloseRequest(ConnectionCloseRequest request);
89 void OnMatchedResponse(ConnectionCloseRequest* request,
90 msgs::PresentationConnectionCloseResponse* response,
91 uint64_t endpoint_id) override;
92 void OnError(ConnectionCloseRequest* request, Error error) override;
93
94 void SendTerminationRequest(TerminationRequest request);
95 void OnMatchedResponse(TerminationRequest* request,
96 msgs::PresentationTerminationResponse* response,
97 uint64_t endpoint_id) override;
98 void OnError(TerminationRequest* request, Error error) override;
99
100 // ProtocolConnectionClient::ConnectionRequestCallback overrides.
101 void OnConnectionOpened(
102 uint64_t request_id,
103 std::unique_ptr<ProtocolConnection> connection) override;
104 void OnConnectionFailed(uint64_t request_id) override;
105
106 // ProtocolConnection::Observer overrides.
107 void OnConnectionClosed(const ProtocolConnection& connection) override;
108
109 private:
110 uint64_t GetNextInternalRequestId();
111
112 Controller* const controller_;
113 const std::string service_id_;
114
115 uint64_t next_internal_request_id_ = 1;
116 ProtocolConnectionClient::ConnectRequest initiation_connect_request_;
117 std::unique_ptr<ProtocolConnection> initiation_protocol_connection_;
118 ProtocolConnectionClient::ConnectRequest connection_connect_request_;
119 std::unique_ptr<ProtocolConnection> connection_protocol_connection_;
120
121 // TODO(btolsch): Improve the ergo of QuicClient::Connect because this is bad.
122 bool initiation_connect_request_stack_{false};
123 bool connection_connect_request_stack_{false};
124
125 RequestResponseHandler<StartRequest> initiation_handler_;
126 RequestResponseHandler<ConnectionOpenRequest> connection_open_handler_;
127 RequestResponseHandler<ConnectionCloseRequest> connection_close_handler_;
128 RequestResponseHandler<TerminationRequest> termination_handler_;
129 };
130
MessageGroupStreams(Controller * controller,const std::string & service_id)131 Controller::MessageGroupStreams::MessageGroupStreams(
132 Controller* controller,
133 const std::string& service_id)
134 : controller_(controller),
135 service_id_(service_id),
136 initiation_handler_(this),
137 connection_open_handler_(this),
138 connection_close_handler_(this),
139 termination_handler_(this) {}
140
141 Controller::MessageGroupStreams::~MessageGroupStreams() = default;
142
SendStartRequest(StartRequest request)143 uint64_t Controller::MessageGroupStreams::SendStartRequest(
144 StartRequest request) {
145 uint64_t request_id = GetNextInternalRequestId();
146 if (!initiation_protocol_connection_ && !initiation_connect_request_) {
147 initiation_connect_request_stack_ = true;
148 initiation_connect_request_ =
149 NetworkServiceManager::Get()->GetProtocolConnectionClient()->Connect(
150 controller_->receiver_endpoints_[service_id_], this);
151 initiation_connect_request_stack_ = false;
152 }
153 initiation_handler_.WriteMessage(request_id, std::move(request));
154 return request_id;
155 }
156
CancelStartRequest(uint64_t request_id)157 void Controller::MessageGroupStreams::CancelStartRequest(uint64_t request_id) {
158 // TODO(btolsch): Instead, mark the |request_id| for immediate termination if
159 // we get a successful response.
160 initiation_handler_.CancelMessage(request_id);
161 }
162
OnMatchedResponse(StartRequest * request,msgs::PresentationStartResponse * response,uint64_t endpoint_id)163 void Controller::MessageGroupStreams::OnMatchedResponse(
164 StartRequest* request,
165 msgs::PresentationStartResponse* response,
166 uint64_t endpoint_id) {
167 if (response->result != msgs::PresentationStartResponse_result::kSuccess) {
168 std::stringstream ss;
169 ss << "presentation-start-response for " << request->request.url
170 << " failed: " << static_cast<int>(response->result);
171 Error error(Error::Code::kUnknownStartError, ss.str());
172 OSP_LOG_INFO << error.message();
173 request->delegate->OnError(std::move(error));
174 return;
175 }
176 OSP_LOG_INFO << "presentation started for " << request->request.url;
177 Controller::ControlledPresentation& presentation =
178 controller_->presentations_[request->request.presentation_id];
179 presentation.service_id = service_id_;
180 presentation.url = request->request.url;
181 auto connection = std::make_unique<Connection>(
182 Connection::PresentationInfo{request->request.presentation_id,
183 request->request.url},
184 request->presentation_connection_delegate, controller_);
185 controller_->OpenConnection(response->connection_id, endpoint_id, service_id_,
186 request->delegate, std::move(connection),
187 NetworkServiceManager::Get()
188 ->GetProtocolConnectionClient()
189 ->CreateProtocolConnection(endpoint_id));
190 }
191
OnError(StartRequest * request,Error error)192 void Controller::MessageGroupStreams::OnError(StartRequest* request,
193 Error error) {
194 request->delegate->OnError(std::move(error));
195 }
196
SendConnectionOpenRequest(ConnectionOpenRequest request)197 uint64_t Controller::MessageGroupStreams::SendConnectionOpenRequest(
198 ConnectionOpenRequest request) {
199 uint64_t request_id = GetNextInternalRequestId();
200 if (!connection_protocol_connection_ && !connection_connect_request_) {
201 connection_connect_request_stack_ = true;
202 connection_connect_request_ =
203 NetworkServiceManager::Get()->GetProtocolConnectionClient()->Connect(
204 controller_->receiver_endpoints_[service_id_], this);
205 connection_connect_request_stack_ = false;
206 }
207 connection_open_handler_.WriteMessage(request_id, std::move(request));
208 return request_id;
209 }
210
CancelConnectionOpenRequest(uint64_t request_id)211 void Controller::MessageGroupStreams::CancelConnectionOpenRequest(
212 uint64_t request_id) {
213 connection_open_handler_.CancelMessage(request_id);
214 }
215
OnMatchedResponse(ConnectionOpenRequest * request,msgs::PresentationConnectionOpenResponse * response,uint64_t endpoint_id)216 void Controller::MessageGroupStreams::OnMatchedResponse(
217 ConnectionOpenRequest* request,
218 msgs::PresentationConnectionOpenResponse* response,
219 uint64_t endpoint_id) {
220 if (response->result !=
221 msgs::PresentationConnectionOpenResponse_result::kSuccess) {
222 std::stringstream ss;
223 ss << "presentation-connection-open-response for " << request->request.url
224 << " failed: " << static_cast<int>(response->result);
225 Error error(Error::Code::kUnknownStartError, ss.str());
226 OSP_LOG_INFO << error.message();
227 request->delegate->OnError(std::move(error));
228 return;
229 }
230 OSP_LOG_INFO << "presentation connection opened to "
231 << request->request.presentation_id;
232 if (request->presentation_connection_delegate) {
233 request->connection = std::make_unique<Connection>(
234 Connection::PresentationInfo{request->request.presentation_id,
235 request->request.url},
236 request->presentation_connection_delegate, controller_);
237 }
238 std::unique_ptr<ProtocolConnection> protocol_connection =
239 NetworkServiceManager::Get()
240 ->GetProtocolConnectionClient()
241 ->CreateProtocolConnection(endpoint_id);
242 request->connection->OnConnected(response->connection_id, endpoint_id,
243 std::move(protocol_connection));
244 controller_->AddConnection(request->connection.get());
245 request->delegate->OnConnection(std::move(request->connection));
246 }
247
OnError(ConnectionOpenRequest * request,Error error)248 void Controller::MessageGroupStreams::OnError(ConnectionOpenRequest* request,
249 Error error) {
250 request->delegate->OnError(std::move(error));
251 }
252
SendConnectionCloseRequest(ConnectionCloseRequest request)253 void Controller::MessageGroupStreams::SendConnectionCloseRequest(
254 ConnectionCloseRequest request) {
255 if (!connection_protocol_connection_ && !connection_connect_request_) {
256 connection_connect_request_stack_ = true;
257 connection_connect_request_ =
258 NetworkServiceManager::Get()->GetProtocolConnectionClient()->Connect(
259 controller_->receiver_endpoints_[service_id_], this);
260 connection_connect_request_stack_ = false;
261 }
262 connection_close_handler_.WriteMessage(std::move(request));
263 }
264
OnMatchedResponse(ConnectionCloseRequest * request,msgs::PresentationConnectionCloseResponse * response,uint64_t endpoint_id)265 void Controller::MessageGroupStreams::OnMatchedResponse(
266 ConnectionCloseRequest* request,
267 msgs::PresentationConnectionCloseResponse* response,
268 uint64_t endpoint_id) {
269 OSP_LOG_IF(INFO,
270 response->result !=
271 msgs::PresentationConnectionCloseResponse_result::kSuccess)
272 << "error in presentation-connection-close-response: "
273 << static_cast<int>(response->result);
274 }
275
OnError(ConnectionCloseRequest * request,Error error)276 void Controller::MessageGroupStreams::OnError(ConnectionCloseRequest* request,
277 Error error) {
278 OSP_LOG_INFO << "got error when closing connection "
279 << request->request.connection_id << ": " << error;
280 }
281
SendTerminationRequest(TerminationRequest request)282 void Controller::MessageGroupStreams::SendTerminationRequest(
283 TerminationRequest request) {
284 if (!initiation_protocol_connection_ && !initiation_connect_request_) {
285 initiation_connect_request_ =
286 NetworkServiceManager::Get()->GetProtocolConnectionClient()->Connect(
287 controller_->receiver_endpoints_[service_id_], this);
288 }
289 termination_handler_.WriteMessage(std::move(request));
290 }
291
OnMatchedResponse(TerminationRequest * request,msgs::PresentationTerminationResponse * response,uint64_t endpoint_id)292 void Controller::MessageGroupStreams::OnMatchedResponse(
293 TerminationRequest* request,
294 msgs::PresentationTerminationResponse* response,
295 uint64_t endpoint_id) {
296 OSP_VLOG << "got presentation-termination-response for "
297 << request->request.presentation_id << " with result "
298 << static_cast<int>(response->result);
299 controller_->TerminatePresentationById(request->request.presentation_id);
300 }
301
OnError(TerminationRequest * request,Error error)302 void Controller::MessageGroupStreams::OnError(TerminationRequest* request,
303 Error error) {}
304
OnConnectionOpened(uint64_t request_id,std::unique_ptr<ProtocolConnection> connection)305 void Controller::MessageGroupStreams::OnConnectionOpened(
306 uint64_t request_id,
307 std::unique_ptr<ProtocolConnection> connection) {
308 if ((initiation_connect_request_ &&
309 initiation_connect_request_.request_id() == request_id) ||
310 initiation_connect_request_stack_) {
311 initiation_protocol_connection_ = std::move(connection);
312 initiation_protocol_connection_->SetObserver(this);
313 initiation_connect_request_.MarkComplete();
314 initiation_handler_.SetConnection(initiation_protocol_connection_.get());
315 termination_handler_.SetConnection(initiation_protocol_connection_.get());
316 } else if ((connection_connect_request_ &&
317 connection_connect_request_.request_id() == request_id) ||
318 connection_connect_request_stack_) {
319 connection_protocol_connection_ = std::move(connection);
320 connection_protocol_connection_->SetObserver(this);
321 connection_connect_request_.MarkComplete();
322 connection_open_handler_.SetConnection(
323 connection_protocol_connection_.get());
324 connection_close_handler_.SetConnection(
325 connection_protocol_connection_.get());
326 }
327 }
328
OnConnectionFailed(uint64_t request_id)329 void Controller::MessageGroupStreams::OnConnectionFailed(uint64_t request_id) {
330 if (initiation_connect_request_ &&
331 initiation_connect_request_.request_id() == request_id) {
332 initiation_connect_request_.MarkComplete();
333 initiation_handler_.Reset();
334 termination_handler_.Reset();
335 } else if (connection_connect_request_ &&
336 connection_connect_request_.request_id() == request_id) {
337 connection_connect_request_.MarkComplete();
338 connection_open_handler_.Reset();
339 connection_close_handler_.Reset();
340 }
341 }
342
OnConnectionClosed(const ProtocolConnection & connection)343 void Controller::MessageGroupStreams::OnConnectionClosed(
344 const ProtocolConnection& connection) {
345 if (&connection == initiation_protocol_connection_.get()) {
346 initiation_handler_.Reset();
347 termination_handler_.Reset();
348 }
349 }
350
GetNextInternalRequestId()351 uint64_t Controller::MessageGroupStreams::GetNextInternalRequestId() {
352 return ++next_internal_request_id_;
353 }
354
355 Controller::ReceiverWatch::ReceiverWatch() = default;
ReceiverWatch(Controller * controller,const std::vector<std::string> & urls,ReceiverObserver * observer)356 Controller::ReceiverWatch::ReceiverWatch(Controller* controller,
357 const std::vector<std::string>& urls,
358 ReceiverObserver* observer)
359 : urls_(urls), observer_(observer), controller_(controller) {}
360
ReceiverWatch(Controller::ReceiverWatch && other)361 Controller::ReceiverWatch::ReceiverWatch(
362 Controller::ReceiverWatch&& other) noexcept {
363 swap(*this, other);
364 }
365
~ReceiverWatch()366 Controller::ReceiverWatch::~ReceiverWatch() {
367 if (observer_) {
368 controller_->CancelReceiverWatch(urls_, observer_);
369 }
370 observer_ = nullptr;
371 }
372
operator =(Controller::ReceiverWatch other)373 Controller::ReceiverWatch& Controller::ReceiverWatch::operator=(
374 Controller::ReceiverWatch other) {
375 swap(*this, other);
376 return *this;
377 }
378
swap(Controller::ReceiverWatch & a,Controller::ReceiverWatch & b)379 void swap(Controller::ReceiverWatch& a, Controller::ReceiverWatch& b) {
380 using std::swap;
381 swap(a.urls_, b.urls_);
382 swap(a.observer_, b.observer_);
383 swap(a.controller_, b.controller_);
384 }
385
386 Controller::ConnectRequest::ConnectRequest() = default;
ConnectRequest(Controller * controller,const std::string & service_id,bool is_reconnect,absl::optional<uint64_t> request_id)387 Controller::ConnectRequest::ConnectRequest(Controller* controller,
388 const std::string& service_id,
389 bool is_reconnect,
390 absl::optional<uint64_t> request_id)
391 : service_id_(service_id),
392 is_reconnect_(is_reconnect),
393 request_id_(request_id),
394 controller_(controller) {}
395
ConnectRequest(ConnectRequest && other)396 Controller::ConnectRequest::ConnectRequest(ConnectRequest&& other) noexcept {
397 swap(*this, other);
398 }
399
~ConnectRequest()400 Controller::ConnectRequest::~ConnectRequest() {
401 if (request_id_) {
402 controller_->CancelConnectRequest(service_id_, is_reconnect_,
403 request_id_.value());
404 }
405 request_id_ = 0;
406 }
407
operator =(ConnectRequest other)408 Controller::ConnectRequest& Controller::ConnectRequest::operator=(
409 ConnectRequest other) {
410 swap(*this, other);
411 return *this;
412 }
413
swap(Controller::ConnectRequest & a,Controller::ConnectRequest & b)414 void swap(Controller::ConnectRequest& a, Controller::ConnectRequest& b) {
415 using std::swap;
416 swap(a.service_id_, b.service_id_);
417 swap(a.is_reconnect_, b.is_reconnect_);
418 swap(a.request_id_, b.request_id_);
419 swap(a.controller_, b.controller_);
420 }
421
Controller(ClockNowFunctionPtr now_function)422 Controller::Controller(ClockNowFunctionPtr now_function) {
423 availability_requester_ =
424 std::make_unique<UrlAvailabilityRequester>(now_function);
425 connection_manager_ =
426 std::make_unique<ConnectionManager>(NetworkServiceManager::Get()
427 ->GetProtocolConnectionClient()
428 ->message_demuxer());
429 const std::vector<ServiceInfo>& receivers =
430 NetworkServiceManager::Get()->GetMdnsServiceListener()->GetReceivers();
431 for (const auto& info : receivers) {
432 // TODO(crbug.com/openscreen/33): Replace service_id with endpoint_id when
433 // endpoint_id is more than just an IPEndpoint counter and actually relates
434 // to a device's identity.
435 receiver_endpoints_.emplace(info.service_id, info.v4_endpoint.port
436 ? info.v4_endpoint
437 : info.v6_endpoint);
438 availability_requester_->AddReceiver(info);
439 }
440 // TODO(btolsch): This is for |receiver_endpoints_|, but this should really be
441 // tracked elsewhere so it's available to other protocols as well.
442 NetworkServiceManager::Get()->GetMdnsServiceListener()->AddObserver(this);
443 }
444
~Controller()445 Controller::~Controller() {
446 connection_manager_.reset();
447 NetworkServiceManager::Get()->GetMdnsServiceListener()->RemoveObserver(this);
448 }
449
RegisterReceiverWatch(const std::vector<std::string> & urls,ReceiverObserver * observer)450 Controller::ReceiverWatch Controller::RegisterReceiverWatch(
451 const std::vector<std::string>& urls,
452 ReceiverObserver* observer) {
453 availability_requester_->AddObserver(urls, observer);
454 return ReceiverWatch(this, urls, observer);
455 }
456
StartPresentation(const std::string & url,const std::string & service_id,RequestDelegate * delegate,Connection::Delegate * conn_delegate)457 Controller::ConnectRequest Controller::StartPresentation(
458 const std::string& url,
459 const std::string& service_id,
460 RequestDelegate* delegate,
461 Connection::Delegate* conn_delegate) {
462 StartRequest request;
463 request.request.url = url;
464 request.request.presentation_id = MakePresentationId(url, service_id);
465 request.delegate = delegate;
466 request.presentation_connection_delegate = conn_delegate;
467 uint64_t request_id =
468 group_streams_[service_id]->SendStartRequest(std::move(request));
469 constexpr bool is_reconnect = false;
470 return ConnectRequest(this, service_id, is_reconnect, request_id);
471 }
472
ReconnectPresentation(const std::vector<std::string> & urls,const std::string & presentation_id,const std::string & service_id,RequestDelegate * delegate,Connection::Delegate * conn_delegate)473 Controller::ConnectRequest Controller::ReconnectPresentation(
474 const std::vector<std::string>& urls,
475 const std::string& presentation_id,
476 const std::string& service_id,
477 RequestDelegate* delegate,
478 Connection::Delegate* conn_delegate) {
479 auto presentation_entry = presentations_.find(presentation_id);
480 if (presentation_entry == presentations_.end()) {
481 delegate->OnError(Error::Code::kNoPresentationFound);
482 return ConnectRequest();
483 }
484 auto matching_url_it =
485 std::find(urls.begin(), urls.end(), presentation_entry->second.url);
486 if (matching_url_it == urls.end()) {
487 delegate->OnError(Error::Code::kNoPresentationFound);
488 return ConnectRequest();
489 }
490 ConnectionOpenRequest request;
491 request.request.url = presentation_entry->second.url;
492 request.request.presentation_id = presentation_id;
493 request.delegate = delegate;
494 request.presentation_connection_delegate = conn_delegate;
495 request.connection = nullptr;
496 uint64_t request_id =
497 group_streams_[service_id]->SendConnectionOpenRequest(std::move(request));
498 constexpr bool is_reconnect = true;
499 return ConnectRequest(this, service_id, is_reconnect, request_id);
500 }
501
ReconnectConnection(std::unique_ptr<Connection> connection,RequestDelegate * delegate)502 Controller::ConnectRequest Controller::ReconnectConnection(
503 std::unique_ptr<Connection> connection,
504 RequestDelegate* delegate) {
505 if (connection->state() != Connection::State::kClosed) {
506 delegate->OnError(Error::Code::kInvalidConnectionState);
507 return ConnectRequest();
508 }
509 const Connection::PresentationInfo& info = connection->presentation_info();
510 auto presentation_entry = presentations_.find(info.id);
511 if (presentation_entry == presentations_.end() ||
512 presentation_entry->second.url != info.url) {
513 OSP_LOG_ERROR << "missing ControlledPresentation for non-terminated "
514 "connection with info ("
515 << info.id << ", " << info.url << ")";
516 delegate->OnError(Error::Code::kNoPresentationFound);
517 return ConnectRequest();
518 }
519 OSP_DCHECK(connection_manager_->GetConnection(connection->connection_id()))
520 << "otherwise valid connection for reconnect is unknown to the "
521 "connection manager";
522 connection_manager_->RemoveConnection(connection.get());
523 connection->OnConnecting();
524 ConnectionOpenRequest request;
525 request.request.url = info.url;
526 request.request.presentation_id = info.id;
527 request.delegate = delegate;
528 request.presentation_connection_delegate = nullptr;
529 request.connection = std::move(connection);
530 const std::string& service_id = presentation_entry->second.service_id;
531 uint64_t request_id =
532 group_streams_[service_id]->SendConnectionOpenRequest(std::move(request));
533 constexpr bool is_reconnect = true;
534 return ConnectRequest(this, service_id, is_reconnect, request_id);
535 }
536
CloseConnection(Connection * connection,Connection::CloseReason reason)537 Error Controller::CloseConnection(Connection* connection,
538 Connection::CloseReason reason) {
539 auto presentation_entry =
540 presentations_.find(connection->presentation_info().id);
541 if (presentation_entry == presentations_.end()) {
542 std::stringstream ss;
543 ss << "no presentation found when trying to close connection "
544 << connection->presentation_info().id << ":"
545 << connection->connection_id();
546 return Error(Error::Code::kNoPresentationFound, ss.str());
547 }
548 ConnectionCloseRequest request;
549 request.request.connection_id = connection->connection_id();
550 group_streams_[presentation_entry->second.service_id]
551 ->SendConnectionCloseRequest(std::move(request));
552 return Error::None();
553 }
554
OnPresentationTerminated(const std::string & presentation_id,TerminationReason reason)555 Error Controller::OnPresentationTerminated(const std::string& presentation_id,
556 TerminationReason reason) {
557 auto presentation_entry = presentations_.find(presentation_id);
558 if (presentation_entry == presentations_.end()) {
559 return Error::Code::kNoPresentationFound;
560 }
561 ControlledPresentation& presentation = presentation_entry->second;
562 for (auto* connection : presentation.connections) {
563 connection->OnTerminated();
564 }
565 TerminationRequest request;
566 request.request.presentation_id = presentation_id;
567 request.request.reason =
568 msgs::PresentationTerminationRequest_reason::kUserTerminatedViaController;
569 group_streams_[presentation.service_id]->SendTerminationRequest(
570 std::move(request));
571 presentations_.erase(presentation_entry);
572 termination_listener_by_id_.erase(presentation_id);
573 return Error::None();
574 }
575
OnConnectionDestroyed(Connection * connection)576 void Controller::OnConnectionDestroyed(Connection* connection) {
577 auto presentation_entry =
578 presentations_.find(connection->presentation_info().id);
579 if (presentation_entry == presentations_.end()) {
580 return;
581 }
582
583 std::vector<Connection*>& connections =
584 presentation_entry->second.connections;
585
586 connections.erase(
587 std::remove(connections.begin(), connections.end(), connection),
588 connections.end());
589
590 connection_manager_->RemoveConnection(connection);
591 }
592
GetServiceIdForPresentationId(const std::string & presentation_id) const593 std::string Controller::GetServiceIdForPresentationId(
594 const std::string& presentation_id) const {
595 auto presentation_entry = presentations_.find(presentation_id);
596 if (presentation_entry == presentations_.end()) {
597 return "";
598 }
599 return presentation_entry->second.service_id;
600 }
601
GetConnectionRequestGroupStream(const std::string & service_id)602 ProtocolConnection* Controller::GetConnectionRequestGroupStream(
603 const std::string& service_id) {
604 OSP_UNIMPLEMENTED();
605 return nullptr;
606 }
607
OnError(ServiceListenerError)608 void Controller::OnError(ServiceListenerError) {}
OnMetrics(ServiceListener::Metrics)609 void Controller::OnMetrics(ServiceListener::Metrics) {}
610
611 class Controller::TerminationListener final
612 : public MessageDemuxer::MessageCallback {
613 public:
614 TerminationListener(Controller* controller,
615 const std::string& presentation_id,
616 uint64_t endpoint_id);
617 ~TerminationListener() override;
618
619 // MessageDemuxer::MessageCallback overrides.
620 ErrorOr<size_t> OnStreamMessage(uint64_t endpoint_id,
621 uint64_t connection_id,
622 msgs::Type message_type,
623 const uint8_t* buffer,
624 size_t buffer_size,
625 Clock::time_point now) override;
626
627 private:
628 Controller* const controller_;
629 std::string presentation_id_;
630 MessageDemuxer::MessageWatch event_watch_;
631 };
632
TerminationListener(Controller * controller,const std::string & presentation_id,uint64_t endpoint_id)633 Controller::TerminationListener::TerminationListener(
634 Controller* controller,
635 const std::string& presentation_id,
636 uint64_t endpoint_id)
637 : controller_(controller), presentation_id_(presentation_id) {
638 event_watch_ =
639 NetworkServiceManager::Get()
640 ->GetProtocolConnectionClient()
641 ->message_demuxer()
642 ->WatchMessageType(endpoint_id,
643 msgs::Type::kPresentationTerminationEvent, this);
644 }
645
646 Controller::TerminationListener::~TerminationListener() = default;
647
OnStreamMessage(uint64_t endpoint_id,uint64_t connection_id,msgs::Type message_type,const uint8_t * buffer,size_t buffer_size,Clock::time_point now)648 ErrorOr<size_t> Controller::TerminationListener::OnStreamMessage(
649 uint64_t endpoint_id,
650 uint64_t connection_id,
651 msgs::Type message_type,
652 const uint8_t* buffer,
653 size_t buffer_size,
654 Clock::time_point now) {
655 OSP_CHECK_EQ(static_cast<int>(msgs::Type::kPresentationTerminationEvent),
656 static_cast<int>(message_type));
657 msgs::PresentationTerminationEvent event;
658 ssize_t result =
659 msgs::DecodePresentationTerminationEvent(buffer, buffer_size, &event);
660 if (result < 0) {
661 OSP_LOG_WARN << "decode presentation-termination-event error: " << result;
662 return Error::Code::kCborParsing;
663 } else if (event.presentation_id != presentation_id_) {
664 OSP_LOG_WARN << "got presentation-termination-event for wrong id: "
665 << presentation_id_ << " vs. " << event.presentation_id;
666 return result;
667 }
668 OSP_LOG_INFO << "termination event";
669 auto presentation_entry =
670 controller_->presentations_.find(event.presentation_id);
671 if (presentation_entry != controller_->presentations_.end()) {
672 for (auto* connection : presentation_entry->second.connections)
673 connection->OnTerminated();
674 controller_->presentations_.erase(presentation_entry);
675 }
676 controller_->termination_listener_by_id_.erase(event.presentation_id);
677 return result;
678 }
679
680 // static
MakePresentationId(const std::string & url,const std::string & service_id)681 std::string Controller::MakePresentationId(const std::string& url,
682 const std::string& service_id) {
683 // TODO(btolsch): This is just a placeholder for the demo. It should
684 // eventually become a GUID/unguessable token routine.
685 std::string safe_id = service_id;
686 for (auto& c : safe_id)
687 if (c < ' ' || c > '~')
688 c = '.';
689 return safe_id + ":" + url;
690 }
691
AddConnection(Connection * connection)692 void Controller::AddConnection(Connection* connection) {
693 connection_manager_->AddConnection(connection);
694 }
695
OpenConnection(uint64_t connection_id,uint64_t endpoint_id,const std::string & service_id,RequestDelegate * request_delegate,std::unique_ptr<Connection> && connection,std::unique_ptr<ProtocolConnection> && protocol_connection)696 void Controller::OpenConnection(
697 uint64_t connection_id,
698 uint64_t endpoint_id,
699 const std::string& service_id,
700 RequestDelegate* request_delegate,
701 std::unique_ptr<Connection>&& connection,
702 std::unique_ptr<ProtocolConnection>&& protocol_connection) {
703 connection->OnConnected(connection_id, endpoint_id,
704 std::move(protocol_connection));
705 const std::string& presentation_id = connection->presentation_info().id;
706 auto presentation_entry = presentations_.find(presentation_id);
707 if (presentation_entry == presentations_.end()) {
708 auto emplace_entry = presentations_.emplace(
709 presentation_id,
710 ControlledPresentation{
711 service_id, connection->presentation_info().url, {}});
712 presentation_entry = emplace_entry.first;
713 }
714 ControlledPresentation& presentation = presentation_entry->second;
715 presentation.connections.push_back(connection.get());
716 AddConnection(connection.get());
717
718 auto terminate_entry = termination_listener_by_id_.find(presentation_id);
719 if (terminate_entry == termination_listener_by_id_.end()) {
720 termination_listener_by_id_.emplace(
721 presentation_id, std::make_unique<TerminationListener>(
722 this, presentation_id, endpoint_id));
723 }
724 request_delegate->OnConnection(std::move(connection));
725 }
726
TerminatePresentationById(const std::string & presentation_id)727 void Controller::TerminatePresentationById(const std::string& presentation_id) {
728 auto presentation_entry = presentations_.find(presentation_id);
729 if (presentation_entry != presentations_.end()) {
730 for (auto* connection : presentation_entry->second.connections) {
731 connection->OnTerminated();
732 }
733 presentations_.erase(presentation_entry);
734 }
735 }
736
CancelReceiverWatch(const std::vector<std::string> & urls,ReceiverObserver * observer)737 void Controller::CancelReceiverWatch(const std::vector<std::string>& urls,
738 ReceiverObserver* observer) {
739 availability_requester_->RemoveObserverUrls(urls, observer);
740 }
741
CancelConnectRequest(const std::string & service_id,bool is_reconnect,uint64_t request_id)742 void Controller::CancelConnectRequest(const std::string& service_id,
743 bool is_reconnect,
744 uint64_t request_id) {
745 auto group_streams_entry = group_streams_.find(service_id);
746 if (group_streams_entry == group_streams_.end())
747 return;
748 if (is_reconnect) {
749 group_streams_entry->second->CancelConnectionOpenRequest(request_id);
750 } else {
751 group_streams_entry->second->CancelStartRequest(request_id);
752 }
753 }
754
OnStarted()755 void Controller::OnStarted() {}
OnStopped()756 void Controller::OnStopped() {}
OnSuspended()757 void Controller::OnSuspended() {}
OnSearching()758 void Controller::OnSearching() {}
759
OnReceiverAdded(const ServiceInfo & info)760 void Controller::OnReceiverAdded(const ServiceInfo& info) {
761 receiver_endpoints_.emplace(info.service_id, info.v4_endpoint.port
762 ? info.v4_endpoint
763 : info.v6_endpoint);
764 auto group_streams =
765 std::make_unique<MessageGroupStreams>(this, info.service_id);
766 group_streams_[info.service_id] = std::move(group_streams);
767 availability_requester_->AddReceiver(info);
768 }
769
OnReceiverChanged(const ServiceInfo & info)770 void Controller::OnReceiverChanged(const ServiceInfo& info) {
771 receiver_endpoints_[info.service_id] =
772 info.v4_endpoint.port ? info.v4_endpoint : info.v6_endpoint;
773 availability_requester_->ChangeReceiver(info);
774 }
775
OnReceiverRemoved(const ServiceInfo & info)776 void Controller::OnReceiverRemoved(const ServiceInfo& info) {
777 receiver_endpoints_.erase(info.service_id);
778 group_streams_.erase(info.service_id);
779 availability_requester_->RemoveReceiver(info);
780 }
781
OnAllReceiversRemoved()782 void Controller::OnAllReceiversRemoved() {
783 receiver_endpoints_.clear();
784 availability_requester_->RemoveAllReceivers();
785 }
786
787 } // namespace osp
788 } // namespace openscreen
789