1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef OSP_PUBLIC_MESSAGE_DEMUXER_H_
6 #define OSP_PUBLIC_MESSAGE_DEMUXER_H_
7 
8 #include <map>
9 #include <memory>
10 #include <vector>
11 
12 #include "osp/msgs/osp_messages.h"
13 #include "platform/api/time.h"
14 #include "platform/base/error.h"
15 
16 namespace openscreen {
17 namespace osp {
18 
19 class QuicStream;
20 
21 // This class separates QUIC stream data into CBOR messages by reading a type
22 // prefix from the stream and passes those messages to any callback matching the
23 // source endpoint and message type.  If there is no callback for a given
24 // message type, it will also try a default message listener.
25 class MessageDemuxer {
26  public:
27   class MessageCallback {
28    public:
29     virtual ~MessageCallback() = default;
30 
31     // |buffer| contains data for a message of type |message_type|.  However,
32     // the data may be incomplete, in which case the callback should return an
33     // error code of Error::Code::kCborIncompleteMessage.  This way,
34     // the MessageDemuxer knows to neither consume the data nor discard it as
35     // bad.
36     virtual ErrorOr<size_t> OnStreamMessage(uint64_t endpoint_id,
37                                             uint64_t connection_id,
38                                             msgs::Type message_type,
39                                             const uint8_t* buffer,
40                                             size_t buffer_size,
41                                             Clock::time_point now) = 0;
42   };
43 
44   class MessageWatch {
45    public:
46     MessageWatch();
47     MessageWatch(MessageDemuxer* parent,
48                  bool is_default,
49                  uint64_t endpoint_id,
50                  msgs::Type message_type);
51     MessageWatch(MessageWatch&&) noexcept;
52     ~MessageWatch();
53     MessageWatch& operator=(MessageWatch&&) noexcept;
54 
55     explicit operator bool() const { return parent_; }
56 
57    private:
58     MessageDemuxer* parent_ = nullptr;
59     bool is_default_;
60     uint64_t endpoint_id_;
61     msgs::Type message_type_;
62   };
63 
64   static constexpr size_t kDefaultBufferLimit = 1 << 16;
65 
66   MessageDemuxer(ClockNowFunctionPtr now_function, size_t buffer_limit);
67   ~MessageDemuxer();
68 
69   // Starts watching for messages of type |message_type| from the endpoint
70   // identified by |endpoint_id|.  When such a message arrives, or if some are
71   // already buffered, |callback| will be called with the message data.
72   MessageWatch WatchMessageType(uint64_t endpoint_id,
73                                 msgs::Type message_type,
74                                 MessageCallback* callback);
75 
76   // Starts watching for messages of type |message_type| from any endpoint when
77   // there is not callback set for its specific endpoint ID.
78   MessageWatch SetDefaultMessageTypeWatch(msgs::Type message_type,
79                                           MessageCallback* callback);
80 
81   // Gives data from |endpoint_id| to the demuxer for processing.
82   // TODO(btolsch): It'd be nice if errors could propagate out of here to close
83   // the stream.
84   void OnStreamData(uint64_t endpoint_id,
85                     uint64_t connection_id,
86                     const uint8_t* data,
87                     size_t data_size);
88 
89  private:
90   struct HandleStreamBufferResult {
91     bool handled;
92     size_t consumed;
93   };
94 
95   void StopWatchingMessageType(uint64_t endpoint_id, msgs::Type message_type);
96   void StopDefaultMessageTypeWatch(msgs::Type message_type);
97 
98   HandleStreamBufferResult HandleStreamBufferLoop(
99       uint64_t endpoint_id,
100       uint64_t connection_id,
101       std::map<uint64_t, std::map<msgs::Type, MessageCallback*>>::iterator
102           endpoint_entry,
103       std::vector<uint8_t>* buffer);
104 
105   HandleStreamBufferResult HandleStreamBuffer(
106       uint64_t endpoint_id,
107       uint64_t connection_id,
108       std::map<msgs::Type, MessageCallback*>* message_callbacks,
109       std::vector<uint8_t>* buffer);
110 
111   const ClockNowFunctionPtr now_function_;
112   const size_t buffer_limit_;
113   std::map<uint64_t, std::map<msgs::Type, MessageCallback*>> message_callbacks_;
114   std::map<msgs::Type, MessageCallback*> default_callbacks_;
115 
116   // Map<endpoint_id, Map<connection_id, data_buffer>>
117   std::map<uint64_t, std::map<uint64_t, std::vector<uint8_t>>> buffers_;
118 };
119 
120 // TODO(btolsch): Make sure all uses of MessageWatch are converted to this
121 // resest function for readability.
122 void StopWatching(MessageDemuxer::MessageWatch* watch);
123 
124 class MessageTypeDecoder {
125  public:
126   static ErrorOr<msgs::Type> DecodeType(const std::vector<uint8_t>& buffer,
127                                         size_t* num_bytes_decoded);
128 
129  private:
130   static ErrorOr<uint64_t> DecodeVarUint(const std::vector<uint8_t>& buffer,
131                                          size_t* num_bytes_decoded);
132 };
133 
134 }  // namespace osp
135 }  // namespace openscreen
136 
137 #endif  // OSP_PUBLIC_MESSAGE_DEMUXER_H_
138