1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_protobuf/decoder.h"
16 
17 #include <cstring>
18 
19 #include "pw_varint/varint.h"
20 
21 namespace pw::protobuf {
22 
Next()23 Status Decoder::Next() {
24   if (!previous_field_consumed_) {
25     if (Status status = SkipField(); !status.ok()) {
26       return status;
27     }
28   }
29   if (proto_.empty()) {
30     return Status::OutOfRange();
31   }
32   previous_field_consumed_ = false;
33   return FieldSize() == 0 ? Status::DataLoss() : OkStatus();
34 }
35 
SkipField()36 Status Decoder::SkipField() {
37   if (proto_.empty()) {
38     return Status::OutOfRange();
39   }
40 
41   size_t bytes_to_skip = FieldSize();
42   if (bytes_to_skip == 0) {
43     return Status::DataLoss();
44   }
45 
46   proto_ = proto_.subspan(bytes_to_skip);
47   return proto_.empty() ? Status::OutOfRange() : OkStatus();
48 }
49 
FieldNumber() const50 uint32_t Decoder::FieldNumber() const {
51   uint64_t key;
52   varint::Decode(proto_, &key);
53   return key >> kFieldNumberShift;
54 }
55 
ReadUint32(uint32_t * out)56 Status Decoder::ReadUint32(uint32_t* out) {
57   uint64_t value = 0;
58   Status status = ReadUint64(&value);
59   if (!status.ok()) {
60     return status;
61   }
62   if (value > std::numeric_limits<uint32_t>::max()) {
63     return Status::OutOfRange();
64   }
65   *out = value;
66   return OkStatus();
67 }
68 
ReadSint32(int32_t * out)69 Status Decoder::ReadSint32(int32_t* out) {
70   int64_t value = 0;
71   Status status = ReadSint64(&value);
72   if (!status.ok()) {
73     return status;
74   }
75   if (value > std::numeric_limits<int32_t>::max()) {
76     return Status::OutOfRange();
77   }
78   *out = value;
79   return OkStatus();
80 }
81 
ReadSint64(int64_t * out)82 Status Decoder::ReadSint64(int64_t* out) {
83   uint64_t value = 0;
84   Status status = ReadUint64(&value);
85   if (!status.ok()) {
86     return status;
87   }
88   *out = varint::ZigZagDecode(value);
89   return OkStatus();
90 }
91 
ReadBool(bool * out)92 Status Decoder::ReadBool(bool* out) {
93   uint64_t value = 0;
94   Status status = ReadUint64(&value);
95   if (!status.ok()) {
96     return status;
97   }
98   *out = value;
99   return OkStatus();
100 }
101 
ReadString(std::string_view * out)102 Status Decoder::ReadString(std::string_view* out) {
103   std::span<const std::byte> bytes;
104   Status status = ReadDelimited(&bytes);
105   if (!status.ok()) {
106     return status;
107   }
108   *out = std::string_view(reinterpret_cast<const char*>(bytes.data()),
109                           bytes.size());
110   return OkStatus();
111 }
112 
FieldSize() const113 size_t Decoder::FieldSize() const {
114   uint64_t key;
115   size_t key_size = varint::Decode(proto_, &key);
116   if (key_size == 0) {
117     return 0;
118   }
119 
120   std::span<const std::byte> remainder = proto_.subspan(key_size);
121   WireType wire_type = static_cast<WireType>(key & kWireTypeMask);
122   uint64_t value = 0;
123   size_t expected_size = 0;
124 
125   switch (wire_type) {
126     case WireType::kVarint:
127       expected_size = varint::Decode(remainder, &value);
128       if (expected_size == 0) {
129         return 0;
130       }
131       break;
132 
133     case WireType::kDelimited:
134       // Varint at cursor indicates size of the field.
135       expected_size = varint::Decode(remainder, &value);
136       if (expected_size == 0) {
137         return 0;
138       }
139       expected_size += value;
140       break;
141 
142     case WireType::kFixed32:
143       expected_size = sizeof(uint32_t);
144       break;
145 
146     case WireType::kFixed64:
147       expected_size = sizeof(uint64_t);
148       break;
149   }
150 
151   if (remainder.size() < expected_size) {
152     return 0;
153   }
154 
155   return key_size + expected_size;
156 }
157 
ConsumeKey(WireType expected_type)158 Status Decoder::ConsumeKey(WireType expected_type) {
159   uint64_t key;
160   size_t bytes_read = varint::Decode(proto_, &key);
161   if (bytes_read == 0) {
162     return Status::FailedPrecondition();
163   }
164 
165   WireType wire_type = static_cast<WireType>(key & kWireTypeMask);
166   if (wire_type != expected_type) {
167     return Status::FailedPrecondition();
168   }
169 
170   // Advance past the key.
171   proto_ = proto_.subspan(bytes_read);
172   return OkStatus();
173 }
174 
ReadVarint(uint64_t * out)175 Status Decoder::ReadVarint(uint64_t* out) {
176   if (Status status = ConsumeKey(WireType::kVarint); !status.ok()) {
177     return status;
178   }
179 
180   size_t bytes_read = varint::Decode(proto_, out);
181   if (bytes_read == 0) {
182     return Status::DataLoss();
183   }
184 
185   // Advance to the next field.
186   proto_ = proto_.subspan(bytes_read);
187   previous_field_consumed_ = true;
188   return OkStatus();
189 }
190 
ReadFixed(std::byte * out,size_t size)191 Status Decoder::ReadFixed(std::byte* out, size_t size) {
192   WireType expected_wire_type =
193       size == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64;
194   Status status = ConsumeKey(expected_wire_type);
195   if (!status.ok()) {
196     return status;
197   }
198 
199   if (proto_.size() < size) {
200     return Status::DataLoss();
201   }
202 
203   std::memcpy(out, proto_.data(), size);
204   proto_ = proto_.subspan(size);
205   previous_field_consumed_ = true;
206 
207   return OkStatus();
208 }
209 
ReadDelimited(std::span<const std::byte> * out)210 Status Decoder::ReadDelimited(std::span<const std::byte>* out) {
211   Status status = ConsumeKey(WireType::kDelimited);
212   if (!status.ok()) {
213     return status;
214   }
215 
216   uint64_t length;
217   size_t bytes_read = varint::Decode(proto_, &length);
218   if (bytes_read == 0) {
219     return Status::DataLoss();
220   }
221 
222   proto_ = proto_.subspan(bytes_read);
223   if (proto_.size() < length) {
224     return Status::DataLoss();
225   }
226 
227   *out = proto_.first(length);
228   proto_ = proto_.subspan(length);
229   previous_field_consumed_ = true;
230 
231   return OkStatus();
232 }
233 
Decode(std::span<const std::byte> proto)234 Status CallbackDecoder::Decode(std::span<const std::byte> proto) {
235   if (handler_ == nullptr || state_ != kReady) {
236     return Status::FailedPrecondition();
237   }
238 
239   state_ = kDecodeInProgress;
240   decoder_.Reset(proto);
241 
242   // Iterate the proto, calling the handler with each field number.
243   while (state_ == kDecodeInProgress) {
244     if (Status status = decoder_.Next(); !status.ok()) {
245       if (status.IsOutOfRange()) {
246         // Reached the end of the proto.
247         break;
248       }
249 
250       // Proto data is malformed.
251       return status;
252     }
253 
254     Status status = handler_->ProcessField(*this, decoder_.FieldNumber());
255     if (!status.ok()) {
256       state_ = status.IsCancelled() ? kDecodeCancelled : kDecodeFailed;
257       return status;
258     }
259 
260     // The callback function can modify the decoder's state; check that
261     // everything is still okay.
262     if (state_ == kDecodeFailed) {
263       break;
264     }
265   }
266 
267   if (state_ != kDecodeInProgress) {
268     return Status::DataLoss();
269   }
270 
271   state_ = kReady;
272   return OkStatus();
273 }
274 
275 }  // namespace pw::protobuf
276