1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 #pragma once
15 
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <span>
#include <type_traits>

#include "pw_bytes/span.h"
#include "pw_protobuf/config.h"
#include "pw_protobuf/wire_format.h"
#include "pw_result/result.h"
#include "pw_status/try.h"
#include "pw_varint/varint.h"
26 
27 namespace pw::protobuf {
28 
29 // A streaming protobuf encoder which encodes to a user-specified buffer.
30 class Encoder {
31  public:
32   using SizeType = config::SizeType;
33 
Encoder(ByteSpan buffer,std::span<SizeType * > locations,std::span<SizeType * > stack)34   constexpr Encoder(ByteSpan buffer,
35                     std::span<SizeType*> locations,
36                     std::span<SizeType*> stack)
37       : buffer_(buffer),
38         cursor_(buffer.data()),
39         blob_locations_(locations),
40         blob_count_(0),
41         blob_stack_(stack),
42         depth_(0),
43         encode_status_(OkStatus()) {}
44 
45   // Disallow copy/assign to avoid confusion about who owns the buffer.
46   Encoder(const Encoder& other) = delete;
47   Encoder& operator=(const Encoder& other) = delete;
48 
49   // Per the protobuf specification, valid field numbers range between 1 and
50   // 2**29 - 1, inclusive. The numbers 19000-19999 are reserved for internal
51   // use.
52   constexpr static uint32_t kMaxFieldNumber = (1u << 29) - 1;
53   constexpr static uint32_t kFirstReservedNumber = 19000;
54   constexpr static uint32_t kLastReservedNumber = 19999;
55 
56   // Writes a proto uint32 key-value pair.
WriteUint32(uint32_t field_number,uint32_t value)57   Status WriteUint32(uint32_t field_number, uint32_t value) {
58     return WriteUint64(field_number, value);
59   }
60 
61   // Writes a repeated uint32 using packed encoding.
WritePackedUint32(uint32_t field_number,std::span<const uint32_t> values)62   Status WritePackedUint32(uint32_t field_number,
63                            std::span<const uint32_t> values) {
64     return WritePackedVarints(field_number, values, /*zigzag=*/false);
65   }
66 
67   // Writes a proto uint64 key-value pair.
68   Status WriteUint64(uint32_t field_number, uint64_t value);
69 
70   // Writes a repeated uint64 using packed encoding.
WritePackedUint64(uint64_t field_number,std::span<const uint64_t> values)71   Status WritePackedUint64(uint64_t field_number,
72                            std::span<const uint64_t> values) {
73     return WritePackedVarints(field_number, values, /*zigzag=*/false);
74   }
75 
76   // Writes a proto int32 key-value pair.
WriteInt32(uint32_t field_number,int32_t value)77   Status WriteInt32(uint32_t field_number, int32_t value) {
78     return WriteUint64(field_number, value);
79   }
80 
81   // Writes a repeated int32 using packed encoding.
WritePackedInt32(uint32_t field_number,std::span<const int32_t> values)82   Status WritePackedInt32(uint32_t field_number,
83                           std::span<const int32_t> values) {
84     return WritePackedVarints(
85         field_number,
86         std::span(reinterpret_cast<const uint32_t*>(values.data()),
87                   values.size()),
88         /*zigzag=*/false);
89   }
90 
91   // Writes a proto int64 key-value pair.
WriteInt64(uint32_t field_number,int64_t value)92   Status WriteInt64(uint32_t field_number, int64_t value) {
93     return WriteUint64(field_number, value);
94   }
95 
96   // Writes a repeated int64 using packed encoding.
WritePackedInt64(uint32_t field_number,std::span<const int64_t> values)97   Status WritePackedInt64(uint32_t field_number,
98                           std::span<const int64_t> values) {
99     return WritePackedVarints(
100         field_number,
101         std::span(reinterpret_cast<const uint64_t*>(values.data()),
102                   values.size()),
103         /*zigzag=*/false);
104   }
105 
106   // Writes a proto sint32 key-value pair.
WriteSint32(uint32_t field_number,int32_t value)107   Status WriteSint32(uint32_t field_number, int32_t value) {
108     return WriteUint64(field_number, varint::ZigZagEncode(value));
109   }
110 
111   // Writes a repeated sint32 using packed encoding.
WritePackedSint32(uint32_t field_number,std::span<const int32_t> values)112   Status WritePackedSint32(uint32_t field_number,
113                            std::span<const int32_t> values) {
114     return WritePackedVarints(
115         field_number,
116         std::span(reinterpret_cast<const uint32_t*>(values.data()),
117                   values.size()),
118         /*zigzag=*/true);
119   }
120 
121   // Writes a proto sint64 key-value pair.
WriteSint64(uint32_t field_number,int64_t value)122   Status WriteSint64(uint32_t field_number, int64_t value) {
123     return WriteUint64(field_number, varint::ZigZagEncode(value));
124   }
125 
126   // Writes a repeated sint64 using packed encoding.
WritePackedSint64(uint32_t field_number,std::span<const int64_t> values)127   Status WritePackedSint64(uint32_t field_number,
128                            std::span<const int64_t> values) {
129     return WritePackedVarints(
130         field_number,
131         std::span(reinterpret_cast<const uint64_t*>(values.data()),
132                   values.size()),
133         /*zigzag=*/true);
134   }
135 
136   // Writes a proto bool key-value pair.
WriteBool(uint32_t field_number,bool value)137   Status WriteBool(uint32_t field_number, bool value) {
138     return WriteUint32(field_number, static_cast<uint32_t>(value));
139   }
140 
141   // Writes a proto fixed32 key-value pair.
WriteFixed32(uint32_t field_number,uint32_t value)142   Status WriteFixed32(uint32_t field_number, uint32_t value) {
143     std::byte* original_cursor = cursor_;
144     WriteFieldKey(field_number, WireType::kFixed32);
145     WriteRawBytes(value);
146     return IncreaseParentSize(cursor_ - original_cursor);
147   }
148 
149   // Writes a repeated fixed32 field using packed encoding.
WritePackedFixed32(uint32_t field_number,std::span<const uint32_t> values)150   Status WritePackedFixed32(uint32_t field_number,
151                             std::span<const uint32_t> values) {
152     return WriteBytes(field_number, std::as_bytes(values));
153   }
154 
155   // Writes a proto fixed64 key-value pair.
WriteFixed64(uint32_t field_number,uint64_t value)156   Status WriteFixed64(uint32_t field_number, uint64_t value) {
157     std::byte* original_cursor = cursor_;
158     WriteFieldKey(field_number, WireType::kFixed64);
159     WriteRawBytes(value);
160     return IncreaseParentSize(cursor_ - original_cursor);
161   }
162 
163   // Writes a repeated fixed64 field using packed encoding.
WritePackedFixed64(uint32_t field_number,std::span<const uint64_t> values)164   Status WritePackedFixed64(uint32_t field_number,
165                             std::span<const uint64_t> values) {
166     return WriteBytes(field_number, std::as_bytes(values));
167   }
168 
169   // Writes a proto sfixed32 key-value pair.
WriteSfixed32(uint32_t field_number,int32_t value)170   Status WriteSfixed32(uint32_t field_number, int32_t value) {
171     return WriteFixed32(field_number, static_cast<uint32_t>(value));
172   }
173 
174   // Writes a repeated sfixed32 field using packed encoding.
WritePackedSfixed32(uint32_t field_number,std::span<const int32_t> values)175   Status WritePackedSfixed32(uint32_t field_number,
176                              std::span<const int32_t> values) {
177     return WriteBytes(field_number, std::as_bytes(values));
178   }
179 
180   // Writes a proto sfixed64 key-value pair.
WriteSfixed64(uint32_t field_number,int64_t value)181   Status WriteSfixed64(uint32_t field_number, int64_t value) {
182     return WriteFixed64(field_number, static_cast<uint64_t>(value));
183   }
184 
185   // Writes a repeated sfixed64 field using packed encoding.
WritePackedSfixed64(uint32_t field_number,std::span<const int64_t> values)186   Status WritePackedSfixed64(uint32_t field_number,
187                              std::span<const int64_t> values) {
188     return WriteBytes(field_number, std::as_bytes(values));
189   }
190 
191   // Writes a proto float key-value pair.
WriteFloat(uint32_t field_number,float value)192   Status WriteFloat(uint32_t field_number, float value) {
193     static_assert(sizeof(float) == sizeof(uint32_t),
194                   "Float and uint32_t are not the same size");
195     std::byte* original_cursor = cursor_;
196     WriteFieldKey(field_number, WireType::kFixed32);
197     WriteRawBytes(value);
198     return IncreaseParentSize(cursor_ - original_cursor);
199   }
200 
201   // Writes a repeated float field using packed encoding.
WritePackedFloat(uint32_t field_number,std::span<const float> values)202   Status WritePackedFloat(uint32_t field_number,
203                           std::span<const float> values) {
204     return WriteBytes(field_number, std::as_bytes(values));
205   }
206 
207   // Writes a proto double key-value pair.
WriteDouble(uint32_t field_number,double value)208   Status WriteDouble(uint32_t field_number, double value) {
209     static_assert(sizeof(double) == sizeof(uint64_t),
210                   "Double and uint64_t are not the same size");
211     std::byte* original_cursor = cursor_;
212     WriteFieldKey(field_number, WireType::kFixed64);
213     WriteRawBytes(value);
214     return IncreaseParentSize(cursor_ - original_cursor);
215   }
216 
217   // Writes a repeated double field using packed encoding.
WritePackedDouble(uint32_t field_number,std::span<const double> values)218   Status WritePackedDouble(uint32_t field_number,
219                            std::span<const double> values) {
220     return WriteBytes(field_number, std::as_bytes(values));
221   }
222 
223   // Writes a proto bytes key-value pair.
WriteBytes(uint32_t field_number,ConstByteSpan value)224   Status WriteBytes(uint32_t field_number, ConstByteSpan value) {
225     std::byte* original_cursor = cursor_;
226     WriteFieldKey(field_number, WireType::kDelimited);
227     WriteVarint(value.size_bytes());
228     WriteRawBytes(value.data(), value.size_bytes());
229     return IncreaseParentSize(cursor_ - original_cursor);
230   }
231 
232   // Writes a proto string key-value pair.
WriteString(uint32_t field_number,const char * value,size_t size)233   Status WriteString(uint32_t field_number, const char* value, size_t size) {
234     return WriteBytes(field_number, std::as_bytes(std::span(value, size)));
235   }
236 
WriteString(uint32_t field_number,const char * value)237   Status WriteString(uint32_t field_number, const char* value) {
238     return WriteString(field_number, value, strlen(value));
239   }
240 
241   // Begins writing a sub-message with a specified field number.
242   Status Push(uint32_t field_number);
243 
244   // Finishes writing a sub-message.
245   Status Pop();
246 
247   // Returns the total encoded size of the proto message.
EncodedSize()248   size_t EncodedSize() const { return cursor_ - buffer_.data(); }
249 
250   // Returns the number of bytes remaining in the buffer.
RemainingSize()251   size_t RemainingSize() const { return buffer_.size() - EncodedSize(); }
252 
253   // Resets write index to the start of the buffer. This invalidates any spans
254   // obtained from Encode().
Clear()255   void Clear() {
256     cursor_ = buffer_.data();
257     encode_status_ = OkStatus();
258     blob_count_ = 0;
259     depth_ = 0;
260   }
261 
262   // Runs a final encoding pass over the intermediary data and returns the
263   // encoded protobuf message.
264   Result<ConstByteSpan> Encode();
265 
266   // DEPRECATED. Use Encode() instead.
267   // TODO(frolv): Remove this after all references to it are updated.
Encode(ConstByteSpan * out)268   Status Encode(ConstByteSpan* out) {
269     Result result = Encode();
270     if (!result.ok()) {
271       return result.status();
272     }
273     *out = result.value();
274     return OkStatus();
275   }
276 
277  private:
ValidFieldNumber(uint32_t field_number)278   constexpr bool ValidFieldNumber(uint32_t field_number) const {
279     return field_number != 0 && field_number <= kMaxFieldNumber &&
280            !(field_number >= kFirstReservedNumber &&
281              field_number <= kLastReservedNumber);
282   }
283 
284   // Encodes the key for a proto field consisting of its number and wire type.
WriteFieldKey(uint32_t field_number,WireType wire_type)285   Status WriteFieldKey(uint32_t field_number, WireType wire_type) {
286     if (!ValidFieldNumber(field_number)) {
287       encode_status_ = Status::InvalidArgument();
288       return encode_status_;
289     }
290 
291     return WriteVarint(MakeKey(field_number, wire_type));
292   }
293 
294   Status WriteVarint(uint64_t value);
295 
WriteZigzagVarint(int64_t value)296   Status WriteZigzagVarint(int64_t value) {
297     return WriteVarint(varint::ZigZagEncode(value));
298   }
299 
300   template <typename T>
WriteRawBytes(const T & value)301   Status WriteRawBytes(const T& value) {
302     return WriteRawBytes(reinterpret_cast<const std::byte*>(&value),
303                          sizeof(value));
304   }
305 
306   Status WriteRawBytes(const std::byte* ptr, size_t size);
307 
308   // Writes a list of varints to the buffer in length-delimited packed encoding.
309   // If zigzag is true, zig-zag encodes each of the varints.
310   template <typename T>
WritePackedVarints(uint32_t field_number,std::span<T> values,bool zigzag)311   Status WritePackedVarints(uint32_t field_number,
312                             std::span<T> values,
313                             bool zigzag) {
314     if (Status status = Push(field_number); !status.ok()) {
315       return status;
316     }
317 
318     std::byte* original_cursor = cursor_;
319     for (T value : values) {
320       if (zigzag) {
321         WriteZigzagVarint(static_cast<std::make_signed_t<T>>(value));
322       } else {
323         WriteVarint(value);
324       }
325     }
326     PW_TRY(IncreaseParentSize(cursor_ - original_cursor));
327 
328     return Pop();
329   }
330 
331   // Adds to the parent proto's size field in the buffer.
332   Status IncreaseParentSize(size_t size_bytes);
333 
334   // Returns the size of `n` encoded as a varint.
VarintSizeBytes(uint64_t n)335   size_t VarintSizeBytes(uint64_t n) {
336     size_t size_bytes = 1;
337     while (n > 127) {
338       ++size_bytes;
339       n >>= 7;
340     }
341     return size_bytes;
342   }
343 
344   // Do the actual (potentially partial) encoding. Also used in Pop().
345   Result<ConstByteSpan> EncodeFrom(size_t blob);
346 
347   // The buffer into which the proto is encoded.
348   ByteSpan buffer_;
349   std::byte* cursor_;
350 
351   // List of pointers to sub-messages' delimiting size fields.
352   std::span<SizeType*> blob_locations_;
353   size_t blob_count_;
354 
355   // Stack of current nested message size locations. Push() operations add a new
356   // entry to this stack and Pop() operations remove one.
357   std::span<SizeType*> blob_stack_;
358   size_t depth_;
359 
360   Status encode_status_;
361 };
362 
363 // A proto encoder with its own blob stack.
364 template <size_t kMaxNestedDepth = 1, size_t kMaxBlobs = kMaxNestedDepth>
365 class NestedEncoder : public Encoder {
366  public:
NestedEncoder(ByteSpan buffer)367   NestedEncoder(ByteSpan buffer) : Encoder(buffer, blobs_, stack_) {}
368 
369   // Disallow copy/assign to avoid confusion about who owns the buffer.
370   NestedEncoder(const NestedEncoder& other) = delete;
371   NestedEncoder& operator=(const NestedEncoder& other) = delete;
372 
373  private:
374   std::array<Encoder::SizeType*, kMaxBlobs> blobs_;
375   std::array<Encoder::SizeType*, kMaxNestedDepth> stack_;
376 };
377 
// User-defined deduction guide mapping an argument-less NestedEncoder to the
// default instantiation NestedEncoder<>. It exists to hide compiler warnings
// about using class template argument deduction with this template
// (presumably Clang's -Wctad-maybe-unsupported, which fires for templates
// that declare no deduction guide of their own — TODO confirm).
NestedEncoder()->NestedEncoder<>;
380 
381 }  // namespace pw::protobuf
382