1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "cast/common/channel/virtual_connection_router.h"
6 
#include <algorithm>
#include <utility>
8 
9 #include "cast/common/channel/cast_message_handler.h"
10 #include "cast/common/channel/connection_namespace_handler.h"
11 #include "cast/common/channel/message_util.h"
12 #include "cast/common/channel/proto/cast_channel.pb.h"
13 #include "util/osp_logging.h"
14 
15 namespace openscreen {
16 namespace cast {
17 
18 using ::cast::channel::CastMessage;
19 
20 VirtualConnectionRouter::VirtualConnectionRouter() = default;
21 
22 VirtualConnectionRouter::~VirtualConnectionRouter() = default;
23 
AddConnection(VirtualConnection virtual_connection,VirtualConnection::AssociatedData associated_data)24 void VirtualConnectionRouter::AddConnection(
25     VirtualConnection virtual_connection,
26     VirtualConnection::AssociatedData associated_data) {
27   auto& socket_map = connections_[virtual_connection.socket_id];
28   auto local_entries = socket_map.equal_range(virtual_connection.local_id);
29   auto it = std::find_if(
30       local_entries.first, local_entries.second,
31       [&virtual_connection](const std::pair<std::string, VCTail>& entry) {
32         return entry.second.peer_id == virtual_connection.peer_id;
33       });
34   if (it == socket_map.end()) {
35     socket_map.emplace(std::move(virtual_connection.local_id),
36                        VCTail{std::move(virtual_connection.peer_id),
37                               std::move(associated_data)});
38   }
39 }
40 
RemoveConnection(const VirtualConnection & virtual_connection,VirtualConnection::CloseReason reason)41 bool VirtualConnectionRouter::RemoveConnection(
42     const VirtualConnection& virtual_connection,
43     VirtualConnection::CloseReason reason) {
44   auto socket_entry = connections_.find(virtual_connection.socket_id);
45   if (socket_entry == connections_.end()) {
46     return false;
47   }
48 
49   auto& socket_map = socket_entry->second;
50   auto local_entries = socket_map.equal_range(virtual_connection.local_id);
51   if (local_entries.first == socket_map.end()) {
52     return false;
53   }
54   for (auto it = local_entries.first; it != local_entries.second; ++it) {
55     if (it->second.peer_id == virtual_connection.peer_id) {
56       socket_map.erase(it);
57       if (socket_map.empty()) {
58         connections_.erase(socket_entry);
59       }
60       return true;
61     }
62   }
63   return false;
64 }
65 
RemoveConnectionsByLocalId(const std::string & local_id)66 void VirtualConnectionRouter::RemoveConnectionsByLocalId(
67     const std::string& local_id) {
68   for (auto socket_entry = connections_.begin();
69        socket_entry != connections_.end();) {
70     auto& socket_map = socket_entry->second;
71     auto local_entries = socket_map.equal_range(local_id);
72     if (local_entries.first != socket_map.end()) {
73       socket_map.erase(local_entries.first, local_entries.second);
74       if (socket_map.empty()) {
75         socket_entry = connections_.erase(socket_entry);
76         continue;
77       }
78     }
79     ++socket_entry;
80   }
81 }
82 
RemoveConnectionsBySocketId(int socket_id)83 void VirtualConnectionRouter::RemoveConnectionsBySocketId(int socket_id) {
84   auto entry = connections_.find(socket_id);
85   if (entry != connections_.end()) {
86     connections_.erase(entry);
87   }
88 }
89 
90 absl::optional<const VirtualConnection::AssociatedData*>
GetConnectionData(const VirtualConnection & virtual_connection) const91 VirtualConnectionRouter::GetConnectionData(
92     const VirtualConnection& virtual_connection) const {
93   auto socket_entry = connections_.find(virtual_connection.socket_id);
94   if (socket_entry == connections_.end()) {
95     return absl::nullopt;
96   }
97 
98   auto& socket_map = socket_entry->second;
99   auto local_entries = socket_map.equal_range(virtual_connection.local_id);
100   if (local_entries.first == socket_map.end()) {
101     return absl::nullopt;
102   }
103   for (auto it = local_entries.first; it != local_entries.second; ++it) {
104     if (it->second.peer_id == virtual_connection.peer_id) {
105       return &it->second.data;
106     }
107   }
108   return absl::nullopt;
109 }
110 
AddHandlerForLocalId(std::string local_id,CastMessageHandler * endpoint)111 bool VirtualConnectionRouter::AddHandlerForLocalId(
112     std::string local_id,
113     CastMessageHandler* endpoint) {
114   return endpoints_.emplace(std::move(local_id), endpoint).second;
115 }
116 
RemoveHandlerForLocalId(const std::string & local_id)117 bool VirtualConnectionRouter::RemoveHandlerForLocalId(
118     const std::string& local_id) {
119   return endpoints_.erase(local_id) == 1u;
120 }
121 
TakeSocket(SocketErrorHandler * error_handler,std::unique_ptr<CastSocket> socket)122 void VirtualConnectionRouter::TakeSocket(SocketErrorHandler* error_handler,
123                                          std::unique_ptr<CastSocket> socket) {
124   int id = socket->socket_id();
125   socket->SetClient(this);
126   sockets_.emplace(id, SocketWithHandler{std::move(socket), error_handler});
127 }
128 
CloseSocket(int id)129 void VirtualConnectionRouter::CloseSocket(int id) {
130   auto it = sockets_.find(id);
131   if (it != sockets_.end()) {
132     RemoveConnectionsBySocketId(id);
133     std::unique_ptr<CastSocket> socket = std::move(it->second.socket);
134     SocketErrorHandler* error_handler = it->second.error_handler;
135     sockets_.erase(it);
136     error_handler->OnClose(socket.get());
137   }
138 }
139 
Send(VirtualConnection virtual_conn,CastMessage message)140 Error VirtualConnectionRouter::Send(VirtualConnection virtual_conn,
141                                     CastMessage message) {
142   if (virtual_conn.peer_id == kBroadcastId) {
143     return BroadcastFromLocalPeer(std::move(virtual_conn.local_id),
144                                   std::move(message));
145   }
146 
147   if (!IsTransportNamespace(message.namespace_()) &&
148       !GetConnectionData(virtual_conn)) {
149     return Error::Code::kNoActiveConnection;
150   }
151   auto it = sockets_.find(virtual_conn.socket_id);
152   if (it == sockets_.end()) {
153     return Error::Code::kItemNotFound;
154   }
155   message.set_source_id(std::move(virtual_conn.local_id));
156   message.set_destination_id(std::move(virtual_conn.peer_id));
157   return it->second.socket->Send(message);
158 }
159 
BroadcastFromLocalPeer(std::string local_id,::cast::channel::CastMessage message)160 Error VirtualConnectionRouter::BroadcastFromLocalPeer(
161     std::string local_id,
162     ::cast::channel::CastMessage message) {
163   message.set_source_id(std::move(local_id));
164   message.set_destination_id(kBroadcastId);
165 
166   // Broadcast to local endpoints.
167   for (const auto& entry : endpoints_) {
168     if (entry.first != message.source_id()) {
169       entry.second->OnMessage(this, nullptr, message);
170     }
171   }
172 
173   // Broadcast to remote endpoints. If an Error occurs, continue broadcasting,
174   // and later return the first Error that occurred.
175   Error error;
176   for (const auto& entry : sockets_) {
177     auto result = entry.second.socket->Send(message);
178     if (!result.ok() && error.ok()) {
179       error = std::move(result);
180     }
181   }
182   return error;
183 }
184 
OnError(CastSocket * socket,Error error)185 void VirtualConnectionRouter::OnError(CastSocket* socket, Error error) {
186   const int id = socket->socket_id();
187   auto it = sockets_.find(id);
188   if (it != sockets_.end()) {
189     RemoveConnectionsBySocketId(id);
190     std::unique_ptr<CastSocket> socket_owned = std::move(it->second.socket);
191     SocketErrorHandler* error_handler = it->second.error_handler;
192     sockets_.erase(it);
193     error_handler->OnError(socket, error);
194   }
195 }
196 
OnMessage(CastSocket * socket,CastMessage message)197 void VirtualConnectionRouter::OnMessage(CastSocket* socket,
198                                         CastMessage message) {
199   OSP_DCHECK(socket);
200 
201   const std::string& local_id = message.destination_id();
202   if (local_id == kBroadcastId) {
203     for (const auto& entry : endpoints_) {
204       entry.second->OnMessage(this, socket, message);
205     }
206   } else {
207     // Connection namespace messages are weird: The message.source_id() and
208     // message.destination_id() are NOT treated as "envelope routing
209     // information," like for all other namespaces. Instead, they are considered
210     // part of the payload data for CONNECT/CLOSE requests. Thus, they require
211     // special-case handling here.
212     if (message.namespace_() == kConnectionNamespace) {
213       if (connection_handler_) {
214         connection_handler_->OnMessage(this, socket, std::move(message));
215       }
216       return;
217     }
218 
219     // Drop all messages for virtual connections that do not yet exist.
220     // Exception: All transport namespace messages (e.g., device auth,
221     // heartbeats, etc.); because these are always assumed to have a route.
222     if (!IsTransportNamespace(message.namespace_()) &&
223         !GetConnectionData(VirtualConnection{local_id, message.source_id(),
224                                              socket->socket_id()})) {
225       return;
226     }
227     auto it = endpoints_.find(local_id);
228     if (it != endpoints_.end()) {
229       it->second->OnMessage(this, socket, std::move(message));
230     }
231   }
232 }
233 
234 }  // namespace cast
235 }  // namespace openscreen
236