1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
#include "pw_protobuf/decoder.h"

#include <cstring>
#include <limits>

#include "pw_varint/varint.h"
20
21 namespace pw::protobuf {
22
Next()23 Status Decoder::Next() {
24 if (!previous_field_consumed_) {
25 if (Status status = SkipField(); !status.ok()) {
26 return status;
27 }
28 }
29 if (proto_.empty()) {
30 return Status::OutOfRange();
31 }
32 previous_field_consumed_ = false;
33 return FieldSize() == 0 ? Status::DataLoss() : OkStatus();
34 }
35
SkipField()36 Status Decoder::SkipField() {
37 if (proto_.empty()) {
38 return Status::OutOfRange();
39 }
40
41 size_t bytes_to_skip = FieldSize();
42 if (bytes_to_skip == 0) {
43 return Status::DataLoss();
44 }
45
46 proto_ = proto_.subspan(bytes_to_skip);
47 return proto_.empty() ? Status::OutOfRange() : OkStatus();
48 }
49
FieldNumber() const50 uint32_t Decoder::FieldNumber() const {
51 uint64_t key;
52 varint::Decode(proto_, &key);
53 return key >> kFieldNumberShift;
54 }
55
ReadUint32(uint32_t * out)56 Status Decoder::ReadUint32(uint32_t* out) {
57 uint64_t value = 0;
58 Status status = ReadUint64(&value);
59 if (!status.ok()) {
60 return status;
61 }
62 if (value > std::numeric_limits<uint32_t>::max()) {
63 return Status::OutOfRange();
64 }
65 *out = value;
66 return OkStatus();
67 }
68
ReadSint32(int32_t * out)69 Status Decoder::ReadSint32(int32_t* out) {
70 int64_t value = 0;
71 Status status = ReadSint64(&value);
72 if (!status.ok()) {
73 return status;
74 }
75 if (value > std::numeric_limits<int32_t>::max()) {
76 return Status::OutOfRange();
77 }
78 *out = value;
79 return OkStatus();
80 }
81
ReadSint64(int64_t * out)82 Status Decoder::ReadSint64(int64_t* out) {
83 uint64_t value = 0;
84 Status status = ReadUint64(&value);
85 if (!status.ok()) {
86 return status;
87 }
88 *out = varint::ZigZagDecode(value);
89 return OkStatus();
90 }
91
ReadBool(bool * out)92 Status Decoder::ReadBool(bool* out) {
93 uint64_t value = 0;
94 Status status = ReadUint64(&value);
95 if (!status.ok()) {
96 return status;
97 }
98 *out = value;
99 return OkStatus();
100 }
101
ReadString(std::string_view * out)102 Status Decoder::ReadString(std::string_view* out) {
103 std::span<const std::byte> bytes;
104 Status status = ReadDelimited(&bytes);
105 if (!status.ok()) {
106 return status;
107 }
108 *out = std::string_view(reinterpret_cast<const char*>(bytes.data()),
109 bytes.size());
110 return OkStatus();
111 }
112
FieldSize() const113 size_t Decoder::FieldSize() const {
114 uint64_t key;
115 size_t key_size = varint::Decode(proto_, &key);
116 if (key_size == 0) {
117 return 0;
118 }
119
120 std::span<const std::byte> remainder = proto_.subspan(key_size);
121 WireType wire_type = static_cast<WireType>(key & kWireTypeMask);
122 uint64_t value = 0;
123 size_t expected_size = 0;
124
125 switch (wire_type) {
126 case WireType::kVarint:
127 expected_size = varint::Decode(remainder, &value);
128 if (expected_size == 0) {
129 return 0;
130 }
131 break;
132
133 case WireType::kDelimited:
134 // Varint at cursor indicates size of the field.
135 expected_size = varint::Decode(remainder, &value);
136 if (expected_size == 0) {
137 return 0;
138 }
139 expected_size += value;
140 break;
141
142 case WireType::kFixed32:
143 expected_size = sizeof(uint32_t);
144 break;
145
146 case WireType::kFixed64:
147 expected_size = sizeof(uint64_t);
148 break;
149 }
150
151 if (remainder.size() < expected_size) {
152 return 0;
153 }
154
155 return key_size + expected_size;
156 }
157
ConsumeKey(WireType expected_type)158 Status Decoder::ConsumeKey(WireType expected_type) {
159 uint64_t key;
160 size_t bytes_read = varint::Decode(proto_, &key);
161 if (bytes_read == 0) {
162 return Status::FailedPrecondition();
163 }
164
165 WireType wire_type = static_cast<WireType>(key & kWireTypeMask);
166 if (wire_type != expected_type) {
167 return Status::FailedPrecondition();
168 }
169
170 // Advance past the key.
171 proto_ = proto_.subspan(bytes_read);
172 return OkStatus();
173 }
174
ReadVarint(uint64_t * out)175 Status Decoder::ReadVarint(uint64_t* out) {
176 if (Status status = ConsumeKey(WireType::kVarint); !status.ok()) {
177 return status;
178 }
179
180 size_t bytes_read = varint::Decode(proto_, out);
181 if (bytes_read == 0) {
182 return Status::DataLoss();
183 }
184
185 // Advance to the next field.
186 proto_ = proto_.subspan(bytes_read);
187 previous_field_consumed_ = true;
188 return OkStatus();
189 }
190
ReadFixed(std::byte * out,size_t size)191 Status Decoder::ReadFixed(std::byte* out, size_t size) {
192 WireType expected_wire_type =
193 size == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64;
194 Status status = ConsumeKey(expected_wire_type);
195 if (!status.ok()) {
196 return status;
197 }
198
199 if (proto_.size() < size) {
200 return Status::DataLoss();
201 }
202
203 std::memcpy(out, proto_.data(), size);
204 proto_ = proto_.subspan(size);
205 previous_field_consumed_ = true;
206
207 return OkStatus();
208 }
209
ReadDelimited(std::span<const std::byte> * out)210 Status Decoder::ReadDelimited(std::span<const std::byte>* out) {
211 Status status = ConsumeKey(WireType::kDelimited);
212 if (!status.ok()) {
213 return status;
214 }
215
216 uint64_t length;
217 size_t bytes_read = varint::Decode(proto_, &length);
218 if (bytes_read == 0) {
219 return Status::DataLoss();
220 }
221
222 proto_ = proto_.subspan(bytes_read);
223 if (proto_.size() < length) {
224 return Status::DataLoss();
225 }
226
227 *out = proto_.first(length);
228 proto_ = proto_.subspan(length);
229 previous_field_consumed_ = true;
230
231 return OkStatus();
232 }
233
Decode(std::span<const std::byte> proto)234 Status CallbackDecoder::Decode(std::span<const std::byte> proto) {
235 if (handler_ == nullptr || state_ != kReady) {
236 return Status::FailedPrecondition();
237 }
238
239 state_ = kDecodeInProgress;
240 decoder_.Reset(proto);
241
242 // Iterate the proto, calling the handler with each field number.
243 while (state_ == kDecodeInProgress) {
244 if (Status status = decoder_.Next(); !status.ok()) {
245 if (status.IsOutOfRange()) {
246 // Reached the end of the proto.
247 break;
248 }
249
250 // Proto data is malformed.
251 return status;
252 }
253
254 Status status = handler_->ProcessField(*this, decoder_.FieldNumber());
255 if (!status.ok()) {
256 state_ = status.IsCancelled() ? kDecodeCancelled : kDecodeFailed;
257 return status;
258 }
259
260 // The callback function can modify the decoder's state; check that
261 // everything is still okay.
262 if (state_ == kDecodeFailed) {
263 break;
264 }
265 }
266
267 if (state_ != kDecodeInProgress) {
268 return Status::DataLoss();
269 }
270
271 state_ = kReady;
272 return OkStatus();
273 }
274
275 } // namespace pw::protobuf
276