1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "cast/sender/public/sender_socket_factory.h"
6
7 #include "cast/common/channel/proto/cast_channel.pb.h"
8 #include "cast/sender/channel/cast_auth_util.h"
9 #include "cast/sender/channel/message_util.h"
10 #include "platform/base/tls_connect_options.h"
11 #include "util/crypto/certificate_utils.h"
12 #include "util/osp_logging.h"
13
14 using ::cast::channel::CastMessage;
15
16 namespace openscreen {
17 namespace cast {
18
operator <(const std::unique_ptr<SenderSocketFactory::PendingAuth> & a,int b)19 bool operator<(const std::unique_ptr<SenderSocketFactory::PendingAuth>& a,
20 int b) {
21 return a && a->socket->socket_id() < b;
22 }
23
operator <(int a,const std::unique_ptr<SenderSocketFactory::PendingAuth> & b)24 bool operator<(int a,
25 const std::unique_ptr<SenderSocketFactory::PendingAuth>& b) {
26 return b && a < b->socket->socket_id();
27 }
28
SenderSocketFactory(Client * client,TaskRunner * task_runner)29 SenderSocketFactory::SenderSocketFactory(Client* client,
30 TaskRunner* task_runner)
31 : client_(client), task_runner_(task_runner) {
32 OSP_DCHECK(client);
33 OSP_DCHECK(task_runner);
34 }
35
~SenderSocketFactory()36 SenderSocketFactory::~SenderSocketFactory() {
37 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
38 }
39
set_factory(TlsConnectionFactory * factory)40 void SenderSocketFactory::set_factory(TlsConnectionFactory* factory) {
41 OSP_DCHECK(factory);
42 factory_ = factory;
43 }
44
Connect(const IPEndpoint & endpoint,DeviceMediaPolicy media_policy,CastSocket::Client * client)45 void SenderSocketFactory::Connect(const IPEndpoint& endpoint,
46 DeviceMediaPolicy media_policy,
47 CastSocket::Client* client) {
48 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
49 OSP_DCHECK(client);
50 auto it = FindPendingConnection(endpoint);
51 if (it == pending_connections_.end()) {
52 pending_connections_.emplace_back(
53 PendingConnection{endpoint, media_policy, client});
54 factory_->Connect(endpoint, TlsConnectOptions{true});
55 }
56 }
57
OnAccepted(TlsConnectionFactory * factory,std::vector<uint8_t> der_x509_peer_cert,std::unique_ptr<TlsConnection> connection)58 void SenderSocketFactory::OnAccepted(
59 TlsConnectionFactory* factory,
60 std::vector<uint8_t> der_x509_peer_cert,
61 std::unique_ptr<TlsConnection> connection) {
62 OSP_NOTREACHED();
63 OSP_LOG_FATAL << "This factory is connect-only";
64 }
65
OnConnected(TlsConnectionFactory * factory,std::vector<uint8_t> der_x509_peer_cert,std::unique_ptr<TlsConnection> connection)66 void SenderSocketFactory::OnConnected(
67 TlsConnectionFactory* factory,
68 std::vector<uint8_t> der_x509_peer_cert,
69 std::unique_ptr<TlsConnection> connection) {
70 const IPEndpoint& endpoint = connection->GetRemoteEndpoint();
71 auto it = FindPendingConnection(endpoint);
72 if (it == pending_connections_.end()) {
73 OSP_DLOG_ERROR << "TLS connection succeeded for unknown endpoint: "
74 << endpoint;
75 return;
76 }
77 DeviceMediaPolicy media_policy = it->media_policy;
78 CastSocket::Client* client = it->client;
79 pending_connections_.erase(it);
80
81 ErrorOr<bssl::UniquePtr<X509>> peer_cert =
82 ImportCertificate(der_x509_peer_cert.data(), der_x509_peer_cert.size());
83 if (!peer_cert) {
84 client_->OnError(this, endpoint, peer_cert.error());
85 return;
86 }
87
88 auto socket =
89 MakeSerialDelete<CastSocket>(task_runner_, std::move(connection), this);
90 pending_auth_.emplace_back(
91 new PendingAuth{endpoint, media_policy, std::move(socket), client,
92 std::make_unique<AuthContext>(AuthContext::Create()),
93 std::move(peer_cert.value())});
94 PendingAuth& pending = *pending_auth_.back();
95
96 CastMessage auth_challenge =
97 CreateAuthChallengeMessage(*pending.auth_context);
98 Error error = pending.socket->Send(auth_challenge);
99 if (!error.ok()) {
100 pending_auth_.pop_back();
101 client_->OnError(this, endpoint, error);
102 }
103 }
104
OnConnectionFailed(TlsConnectionFactory * factory,const IPEndpoint & remote_address)105 void SenderSocketFactory::OnConnectionFailed(TlsConnectionFactory* factory,
106 const IPEndpoint& remote_address) {
107 auto it = FindPendingConnection(remote_address);
108 if (it == pending_connections_.end()) {
109 OSP_DVLOG << "OnConnectionFailed reported for untracked address: "
110 << remote_address;
111 return;
112 }
113 pending_connections_.erase(it);
114 client_->OnError(this, remote_address, Error::Code::kConnectionFailed);
115 }
116
OnError(TlsConnectionFactory * factory,Error error)117 void SenderSocketFactory::OnError(TlsConnectionFactory* factory, Error error) {
118 std::vector<PendingConnection> connections;
119 pending_connections_.swap(connections);
120 for (const PendingConnection& pending : connections) {
121 client_->OnError(this, pending.endpoint, error);
122 }
123 }
124
125 std::vector<SenderSocketFactory::PendingConnection>::iterator
FindPendingConnection(const IPEndpoint & endpoint)126 SenderSocketFactory::FindPendingConnection(const IPEndpoint& endpoint) {
127 return std::find_if(pending_connections_.begin(), pending_connections_.end(),
128 [&endpoint](const PendingConnection& pending) {
129 return pending.endpoint == endpoint;
130 });
131 }
132
OnError(CastSocket * socket,Error error)133 void SenderSocketFactory::OnError(CastSocket* socket, Error error) {
134 auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(),
135 [id = socket->socket_id()](
136 const std::unique_ptr<PendingAuth>& pending_auth) {
137 return pending_auth->socket->socket_id() == id;
138 });
139 if (it == pending_auth_.end()) {
140 OSP_DLOG_ERROR << "Got error for unknown pending socket";
141 return;
142 }
143 IPEndpoint endpoint = (*it)->endpoint;
144 pending_auth_.erase(it);
145 client_->OnError(this, endpoint, error);
146 }
147
OnMessage(CastSocket * socket,CastMessage message)148 void SenderSocketFactory::OnMessage(CastSocket* socket, CastMessage message) {
149 auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(),
150 [id = socket->socket_id()](
151 const std::unique_ptr<PendingAuth>& pending_auth) {
152 return pending_auth->socket->socket_id() == id;
153 });
154 if (it == pending_auth_.end()) {
155 OSP_DLOG_ERROR << "Got message for unknown pending socket";
156 return;
157 }
158
159 std::unique_ptr<PendingAuth> pending = std::move(*it);
160 pending_auth_.erase(it);
161 if (!IsAuthMessage(message)) {
162 client_->OnError(this, pending->endpoint,
163 Error::Code::kCastV2AuthenticationError);
164 return;
165 }
166
167 ErrorOr<CastDeviceCertPolicy> policy_or_error = AuthenticateChallengeReply(
168 message, pending->peer_cert.get(), *pending->auth_context);
169 if (policy_or_error.is_error()) {
170 OSP_DLOG_WARN << "Authentication failed for " << pending->endpoint
171 << " with error: " << policy_or_error.error();
172 client_->OnError(this, pending->endpoint, policy_or_error.error());
173 return;
174 }
175
176 if (policy_or_error.value() == CastDeviceCertPolicy::kAudioOnly &&
177 pending->media_policy == DeviceMediaPolicy::kIncludesVideo) {
178 client_->OnError(this, pending->endpoint,
179 Error::Code::kCastV2ChannelPolicyMismatch);
180 return;
181 }
182 pending->socket->set_audio_only(policy_or_error.value() ==
183 CastDeviceCertPolicy::kAudioOnly);
184
185 pending->socket->SetClient(pending->client);
186 client_->OnConnected(this, pending->endpoint,
187 std::unique_ptr<CastSocket>(pending->socket.release()));
188 }
189
190 } // namespace cast
191 } // namespace openscreen
192