1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "osp/public/message_demuxer.h"
6 
7 #include <memory>
8 #include <utility>
9 
10 #include "osp/impl/quic/quic_connection.h"
11 #include "platform/base/error.h"
12 #include "util/big_endian.h"
13 #include "util/osp_logging.h"
14 
15 namespace openscreen {
16 namespace osp {
17 
18 // static
19 // Decodes a varUint, expecting it to follow the encoding format described here:
20 // https://tools.ietf.org/html/draft-ietf-quic-transport-16#section-16
DecodeVarUint(const std::vector<uint8_t> & buffer,size_t * num_bytes_decoded)21 ErrorOr<uint64_t> MessageTypeDecoder::DecodeVarUint(
22     const std::vector<uint8_t>& buffer,
23     size_t* num_bytes_decoded) {
24   if (buffer.size() == 0) {
25     return Error::Code::kCborIncompleteMessage;
26   }
27 
28   uint8_t num_type_bytes = static_cast<uint8_t>(buffer[0] >> 6 & 0x03);
29   *num_bytes_decoded = 0x1 << num_type_bytes;
30 
31   // Ensure that ReadBigEndian won't read beyond the end of the buffer. Also,
32   // since we expect the id to be followed by the message, equality is not valid
33   if (buffer.size() <= *num_bytes_decoded) {
34     return Error::Code::kCborIncompleteMessage;
35   }
36 
37   switch (num_type_bytes) {
38     case 0:
39       return ReadBigEndian<uint8_t>(&buffer[0]) & ~0xC0;
40     case 1:
41       return ReadBigEndian<uint16_t>(&buffer[0]) & ~(0xC0 << 8);
42     case 2:
43       return ReadBigEndian<uint32_t>(&buffer[0]) & ~(0xC0 << 24);
44     case 3:
45       return ReadBigEndian<uint64_t>(&buffer[0]) & ~(uint64_t{0xC0} << 56);
46     default:
47       OSP_NOTREACHED();
48   }
49 }
50 
51 // static
52 // Decodes the Type of message, expecting it to follow the encoding format
53 // described here:
54 // https://tools.ietf.org/html/draft-ietf-quic-transport-16#section-16
DecodeType(const std::vector<uint8_t> & buffer,size_t * num_bytes_decoded)55 ErrorOr<msgs::Type> MessageTypeDecoder::DecodeType(
56     const std::vector<uint8_t>& buffer,
57     size_t* num_bytes_decoded) {
58   ErrorOr<uint64_t> message_type =
59       MessageTypeDecoder::DecodeVarUint(buffer, num_bytes_decoded);
60   if (message_type.is_error()) {
61     return message_type.error();
62   }
63 
64   msgs::Type parsed_type =
65       msgs::TypeEnumValidator::SafeCast(message_type.value());
66   if (parsed_type == msgs::Type::kUnknown) {
67     return Error::Code::kCborInvalidMessage;
68   }
69 
70   return parsed_type;
71 }
72 
73 // static
74 constexpr size_t MessageDemuxer::kDefaultBufferLimit;
75 
76 MessageDemuxer::MessageWatch::MessageWatch() = default;
77 
MessageWatch(MessageDemuxer * parent,bool is_default,uint64_t endpoint_id,msgs::Type message_type)78 MessageDemuxer::MessageWatch::MessageWatch(MessageDemuxer* parent,
79                                            bool is_default,
80                                            uint64_t endpoint_id,
81                                            msgs::Type message_type)
82     : parent_(parent),
83       is_default_(is_default),
84       endpoint_id_(endpoint_id),
85       message_type_(message_type) {}
86 
MessageWatch(MessageDemuxer::MessageWatch && other)87 MessageDemuxer::MessageWatch::MessageWatch(
88     MessageDemuxer::MessageWatch&& other) noexcept
89     : parent_(other.parent_),
90       is_default_(other.is_default_),
91       endpoint_id_(other.endpoint_id_),
92       message_type_(other.message_type_) {
93   other.parent_ = nullptr;
94 }
95 
~MessageWatch()96 MessageDemuxer::MessageWatch::~MessageWatch() {
97   if (parent_) {
98     if (is_default_) {
99       OSP_VLOG << "dropping default handler for type: "
100                << static_cast<int>(message_type_);
101       parent_->StopDefaultMessageTypeWatch(message_type_);
102     } else {
103       OSP_VLOG << "dropping handler for type: "
104                << static_cast<int>(message_type_);
105       parent_->StopWatchingMessageType(endpoint_id_, message_type_);
106     }
107   }
108 }
109 
operator =(MessageWatch && other)110 MessageDemuxer::MessageWatch& MessageDemuxer::MessageWatch::operator=(
111     MessageWatch&& other) noexcept {
112   using std::swap;
113   swap(parent_, other.parent_);
114   swap(is_default_, other.is_default_);
115   swap(endpoint_id_, other.endpoint_id_);
116   swap(message_type_, other.message_type_);
117   return *this;
118 }
119 
MessageDemuxer(ClockNowFunctionPtr now_function,size_t buffer_limit=kDefaultBufferLimit)120 MessageDemuxer::MessageDemuxer(ClockNowFunctionPtr now_function,
121                                size_t buffer_limit = kDefaultBufferLimit)
122     : now_function_(now_function), buffer_limit_(buffer_limit) {
123   OSP_DCHECK(now_function_);
124 }
125 
126 MessageDemuxer::~MessageDemuxer() = default;
127 
WatchMessageType(uint64_t endpoint_id,msgs::Type message_type,MessageCallback * callback)128 MessageDemuxer::MessageWatch MessageDemuxer::WatchMessageType(
129     uint64_t endpoint_id,
130     msgs::Type message_type,
131     MessageCallback* callback) {
132   auto callbacks_entry = message_callbacks_.find(endpoint_id);
133   if (callbacks_entry == message_callbacks_.end()) {
134     callbacks_entry =
135         message_callbacks_
136             .emplace(endpoint_id, std::map<msgs::Type, MessageCallback*>{})
137             .first;
138   }
139   auto emplace_result = callbacks_entry->second.emplace(message_type, callback);
140   if (!emplace_result.second)
141     return MessageWatch();
142   auto endpoint_entry = buffers_.find(endpoint_id);
143   if (endpoint_entry != buffers_.end()) {
144     for (auto& buffer : endpoint_entry->second) {
145       if (buffer.second.empty())
146         continue;
147       auto buffered_type = static_cast<msgs::Type>(buffer.second[0]);
148       if (message_type == buffered_type) {
149         HandleStreamBufferLoop(endpoint_id, buffer.first, callbacks_entry,
150                                &buffer.second);
151       }
152     }
153   }
154   return MessageWatch(this, false, endpoint_id, message_type);
155 }
156 
SetDefaultMessageTypeWatch(msgs::Type message_type,MessageCallback * callback)157 MessageDemuxer::MessageWatch MessageDemuxer::SetDefaultMessageTypeWatch(
158     msgs::Type message_type,
159     MessageCallback* callback) {
160   auto emplace_result = default_callbacks_.emplace(message_type, callback);
161   if (!emplace_result.second)
162     return MessageWatch();
163   for (auto& endpoint_buffers : buffers_) {
164     auto endpoint_id = endpoint_buffers.first;
165     for (auto& stream_map : endpoint_buffers.second) {
166       if (stream_map.second.empty())
167         continue;
168       auto buffered_type = static_cast<msgs::Type>(stream_map.second[0]);
169       if (message_type == buffered_type) {
170         auto connection_id = stream_map.first;
171         auto callbacks_entry = message_callbacks_.find(endpoint_id);
172         HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry,
173                                &stream_map.second);
174       }
175     }
176   }
177   return MessageWatch(this, true, 0, message_type);
178 }
179 
OnStreamData(uint64_t endpoint_id,uint64_t connection_id,const uint8_t * data,size_t data_size)180 void MessageDemuxer::OnStreamData(uint64_t endpoint_id,
181                                   uint64_t connection_id,
182                                   const uint8_t* data,
183                                   size_t data_size) {
184   OSP_VLOG << __func__ << ": [" << endpoint_id << ", " << connection_id
185            << "] - (" << data_size << ")";
186   auto& stream_map = buffers_[endpoint_id];
187   if (!data_size) {
188     stream_map.erase(connection_id);
189     if (stream_map.empty())
190       buffers_.erase(endpoint_id);
191     return;
192   }
193   std::vector<uint8_t>& buffer = stream_map[connection_id];
194   buffer.insert(buffer.end(), data, data + data_size);
195 
196   auto callbacks_entry = message_callbacks_.find(endpoint_id);
197   HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry, &buffer);
198 
199   if (buffer.size() > buffer_limit_)
200     stream_map.erase(connection_id);
201 }
202 
StopWatchingMessageType(uint64_t endpoint_id,msgs::Type message_type)203 void MessageDemuxer::StopWatchingMessageType(uint64_t endpoint_id,
204                                              msgs::Type message_type) {
205   auto& message_map = message_callbacks_[endpoint_id];
206   auto it = message_map.find(message_type);
207   message_map.erase(it);
208 }
209 
StopDefaultMessageTypeWatch(msgs::Type message_type)210 void MessageDemuxer::StopDefaultMessageTypeWatch(msgs::Type message_type) {
211   default_callbacks_.erase(message_type);
212 }
213 
HandleStreamBufferLoop(uint64_t endpoint_id,uint64_t connection_id,std::map<uint64_t,std::map<msgs::Type,MessageCallback * >>::iterator callbacks_entry,std::vector<uint8_t> * buffer)214 MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBufferLoop(
215     uint64_t endpoint_id,
216     uint64_t connection_id,
217     std::map<uint64_t, std::map<msgs::Type, MessageCallback*>>::iterator
218         callbacks_entry,
219     std::vector<uint8_t>* buffer) {
220   HandleStreamBufferResult result;
221   do {
222     result = {false, 0};
223     if (callbacks_entry != message_callbacks_.end()) {
224       OSP_VLOG << "attempting endpoint-specific handling";
225       result = HandleStreamBuffer(endpoint_id, connection_id,
226                                   &callbacks_entry->second, buffer);
227     }
228     if (!result.handled) {
229       if (!default_callbacks_.empty()) {
230         OSP_VLOG << "attempting generic message handling";
231         result = HandleStreamBuffer(endpoint_id, connection_id,
232                                     &default_callbacks_, buffer);
233       }
234     }
235     OSP_VLOG_IF(!result.handled) << "no message handler matched";
236   } while (result.consumed && !buffer->empty());
237   return result;
238 }
239 
240 // TODO(rwkeane) Use absl::Span for the buffer
HandleStreamBuffer(uint64_t endpoint_id,uint64_t connection_id,std::map<msgs::Type,MessageCallback * > * message_callbacks,std::vector<uint8_t> * buffer)241 MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBuffer(
242     uint64_t endpoint_id,
243     uint64_t connection_id,
244     std::map<msgs::Type, MessageCallback*>* message_callbacks,
245     std::vector<uint8_t>* buffer) {
246   size_t consumed = 0;
247   size_t total_consumed = 0;
248   bool handled = false;
249   do {
250     consumed = 0;
251     size_t msg_type_byte_length;
252     ErrorOr<msgs::Type> message_type =
253         MessageTypeDecoder::DecodeType(*buffer, &msg_type_byte_length);
254     if (message_type.is_error()) {
255       buffer->clear();
256       break;
257     }
258     auto callback_entry = message_callbacks->find(message_type.value());
259     if (callback_entry == message_callbacks->end())
260       break;
261     handled = true;
262     OSP_VLOG << "handling message type "
263              << static_cast<int>(message_type.value());
264     auto consumed_or_error = callback_entry->second->OnStreamMessage(
265         endpoint_id, connection_id, message_type.value(),
266         buffer->data() + msg_type_byte_length,
267         buffer->size() - msg_type_byte_length, now_function_());
268     if (!consumed_or_error) {
269       if (consumed_or_error.error().code() !=
270           Error::Code::kCborIncompleteMessage) {
271         buffer->clear();
272         break;
273       }
274     } else {
275       consumed = consumed_or_error.value();
276       buffer->erase(buffer->begin(),
277                     buffer->begin() + consumed + msg_type_byte_length);
278     }
279     total_consumed += consumed;
280   } while (consumed && !buffer->empty());
281   return HandleStreamBufferResult{handled, total_consumed};
282 }
283 
StopWatching(MessageDemuxer::MessageWatch * watch)284 void StopWatching(MessageDemuxer::MessageWatch* watch) {
285   *watch = MessageDemuxer::MessageWatch();
286 }
287 
288 }  // namespace osp
289 }  // namespace openscreen
290