// Copyright 2018 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #ifndef OSP_PUBLIC_MESSAGE_DEMUXER_H_ #define OSP_PUBLIC_MESSAGE_DEMUXER_H_ #include #include #include #include "osp/msgs/osp_messages.h" #include "platform/api/time.h" #include "platform/base/error.h" namespace openscreen { namespace osp { class QuicStream; // This class separates QUIC stream data into CBOR messages by reading a type // prefix from the stream and passes those messages to any callback matching the // source endpoint and message type. If there is no callback for a given // message type, it will also try a default message listener. class MessageDemuxer { public: class MessageCallback { public: virtual ~MessageCallback() = default; // |buffer| contains data for a message of type |message_type|. However, // the data may be incomplete, in which case the callback should return an // error code of Error::Code::kCborIncompleteMessage. This way, // the MessageDemuxer knows to neither consume the data nor discard it as // bad. virtual ErrorOr OnStreamMessage(uint64_t endpoint_id, uint64_t connection_id, msgs::Type message_type, const uint8_t* buffer, size_t buffer_size, Clock::time_point now) = 0; }; class MessageWatch { public: MessageWatch(); MessageWatch(MessageDemuxer* parent, bool is_default, uint64_t endpoint_id, msgs::Type message_type); MessageWatch(MessageWatch&&) noexcept; ~MessageWatch(); MessageWatch& operator=(MessageWatch&&) noexcept; explicit operator bool() const { return parent_; } private: MessageDemuxer* parent_ = nullptr; bool is_default_; uint64_t endpoint_id_; msgs::Type message_type_; }; static constexpr size_t kDefaultBufferLimit = 1 << 16; MessageDemuxer(ClockNowFunctionPtr now_function, size_t buffer_limit); ~MessageDemuxer(); // Starts watching for messages of type |message_type| from the endpoint // identified by |endpoint_id|. When such a message arrives, or if some are // already buffered, |callback| will be called with the message data. MessageWatch WatchMessageType(uint64_t endpoint_id, msgs::Type message_type, MessageCallback* callback); // Starts watching for messages of type |message_type| from any endpoint when // there is not callback set for its specific endpoint ID. MessageWatch SetDefaultMessageTypeWatch(msgs::Type message_type, MessageCallback* callback); // Gives data from |endpoint_id| to the demuxer for processing. // TODO(btolsch): It'd be nice if errors could propagate out of here to close // the stream. void OnStreamData(uint64_t endpoint_id, uint64_t connection_id, const uint8_t* data, size_t data_size); private: struct HandleStreamBufferResult { bool handled; size_t consumed; }; void StopWatchingMessageType(uint64_t endpoint_id, msgs::Type message_type); void StopDefaultMessageTypeWatch(msgs::Type message_type); HandleStreamBufferResult HandleStreamBufferLoop( uint64_t endpoint_id, uint64_t connection_id, std::map>::iterator endpoint_entry, std::vector* buffer); HandleStreamBufferResult HandleStreamBuffer( uint64_t endpoint_id, uint64_t connection_id, std::map* message_callbacks, std::vector* buffer); const ClockNowFunctionPtr now_function_; const size_t buffer_limit_; std::map> message_callbacks_; std::map default_callbacks_; // Map> std::map>> buffers_; }; // TODO(btolsch): Make sure all uses of MessageWatch are converted to this // resest function for readability. void StopWatching(MessageDemuxer::MessageWatch* watch); class MessageTypeDecoder { public: static ErrorOr DecodeType(const std::vector& buffer, size_t* num_bytes_decoded); private: static ErrorOr DecodeVarUint(const std::vector& buffer, size_t* num_bytes_decoded); }; } // namespace osp } // namespace openscreen #endif // OSP_PUBLIC_MESSAGE_DEMUXER_H_