1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <map>
#include <memory>
#include <new>
#include <optional>
#include <span>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include <lk/compiler.h>

#include <dice/cbor_reader.h>
#include <dice/cbor_writer.h>
32 
33 namespace cbor {
34 
35 /**
36  * readCborBoolean() - Read boolean value from CBOR input object.
37  * @in: Initialized CBOR input object to read from
38  *
39  * Return: boolean if read succeeds, %nullopt otherwise
40  */
readCborBoolean(struct CborIn & in)41 static inline std::optional<bool> readCborBoolean(struct CborIn& in) {
42     if (CborReadTrue(&in) == CBOR_READ_RESULT_OK) {
43         return true;
44     } else if (CborReadFalse(&in) == CBOR_READ_RESULT_OK) {
45         return false;
46     } else {
47         return std::nullopt;
48     }
49 }
50 
/**
 * encodedSizeOf() - Get number of bytes required to encode CBOR item.
 * @val: Value (for types with no content) or length of CBOR item.
 *
 * Every CBOR data item begins with a head: one initial byte (major type plus
 * 5 bits of additional information), optionally followed by a 1, 2, 4 or 8
 * byte argument. The argument is the value of an unsigned integer or the
 * length of a string, array or map, so this helper serves all of those major
 * types. For items that have content, the content bytes are NOT included.
 *
 * Return: Encoded size of CBOR value less the size of its content if any.
 */
static inline size_t encodedSizeOf(uint64_t val) {
    if (val < 24) {
        /* Small arguments are stored directly in the initial byte. */
        return 1;
    }
    if (val <= 0xffu) {
        return 1 + 1;
    }
    if (val <= 0xffffu) {
        return 1 + 2;
    }
    if (val <= 0xffffffffu) {
        return 1 + 4;
    }
    return 1 + 8;
}
71 
72 /**
73  * encodedSizeOfInt() - Get number of bytes required to encode signed integer.
74  * @val: Integer to encode.
75  *
76  * Return: Encoded size of signed integer.
77  */
encodedSizeOfInt(int64_t val)78 static inline size_t encodedSizeOfInt(int64_t val) {
79     uint8_t buffer[16];
80     struct CborOut out;
81     CborOutInit(buffer, sizeof(buffer), &out);
82     CborWriteInt(val, &out);
83     assert(!CborOutOverflowed(&out));
84     return CborOutSize(&out);
85 }
86 
87 /**
88  * Sorts the map in canonical order, as defined in RFC 7049.
89  * https://datatracker.ietf.org/doc/html/rfc7049#section-3.9
90  */
91 template <typename T>
92 struct CBORCompare {
operatorCBORCompare93     constexpr bool operator()(const std::vector<T>& a,
94                               const std::vector<T>& b) const {
95         return keyLess(a, b);
96     }
97 
98     /* Returns true iff key a sorts before key b in CBOR order */
keyLessCBORCompare99     constexpr bool keyLess(const std::span<const T>& a,
100                            const std::span<const T>& b) const {
101         /* If two keys have different lengths, the shorter one sorts earlier */
102         if (a.size() < b.size())
103             return true;
104         if (a.size() > b.size())
105             return false;
106 
107         /* If keys have the same length, do a byte-wise comparison */
108         return std::lexicographical_compare(a.begin(), a.end(), b.begin(),
109                                             b.end());
110     }
111 };
112 
113 using CborMap = std::map<std::vector<uint8_t>,
114                          std::span<const uint8_t>,
115                          CBORCompare<uint8_t>>;
116 
populateMap(CborMap & map,const std::span<const uint8_t> & encMap)117 static inline bool populateMap(CborMap& map,
118                                const std::span<const uint8_t>& encMap) {
119     if (!encMap.size()) {
120         /* No elements to add to map */
121         return true;
122     }
123 
124     struct CborIn in;
125     CborInInit(encMap.data(), encMap.size(), &in);
126 
127     size_t numPairs;
128     if (CborReadMap(&in, &numPairs) != CBOR_READ_RESULT_OK) {
129         return false;
130     }
131 
132     int64_t key;
133     struct CborOut out;
134     struct CborIn savedIn;
135     for (size_t i = 0; i < numPairs; i++) {
136         /* Read key */
137         if (CborReadInt(&in, &key) != CBOR_READ_RESULT_OK) {
138             return false;
139         }
140 
141         /* skip value */
142         savedIn = in;
143         if (CborReadSkip(&in) != CBOR_READ_RESULT_OK) {
144             return false;
145         }
146 
147         std::vector<uint8_t> encKey(encodedSizeOfInt(key));
148         CborOutInit(encKey.data(), encKey.size(), &out);
149         CborWriteInt(key, &out);
150         assert(!CborOutOverflowed(&out));
151 
152         std::span value(savedIn.buffer + savedIn.cursor,
153                         in.cursor - savedIn.cursor);
154 
155         map[std::move(encKey)] = value;
156     }
157 
158     return true;
159 }
160 
161 /**
162  * mergeMaps() - Merge the items in two CBOR maps, return canonical map.
163  *
164  * @lhs: CBOR-encoded map using signed integers as keys
165  * @rhs: CBOR-encoded map using signed integers as keys
166  *
167  * Return:
168  *      Canonical CBOR encoding of the combined map or %nullopt if an error
169  *      occurred.
170  */
mergeMaps(const std::span<const uint8_t> & lhs,const std::span<const uint8_t> & rhs)171 static inline std::optional<std::vector<uint8_t>> mergeMaps(
172         const std::span<const uint8_t>& lhs,
173         const std::span<const uint8_t>& rhs) {
174     /*
175      * map is sorted on the encoded key which ensures that the CBOR encoding is
176      * canonical.
177      */
178     CborMap map;
179 
180     if (!populateMap(map, lhs)) {
181         return std::nullopt;
182     }
183     if (!populateMap(map, rhs)) {
184         return std::nullopt;
185     }
186 
187     size_t outputSize = encodedSizeOf(map.size());
188     for (const auto& [key, value] : map) {
189         outputSize += key.size() + value.size();
190     }
191 
192     auto output = std::vector<uint8_t>(outputSize);
193     struct CborOut out;
194     CborOutInit(output.data(), output.size(), &out);
195     CborWriteMap(map.size(), &out);
196     for (const auto& [key, value] : map) {
197         /* insert key */
198         std::memcpy(output.data() + out.cursor, key.data(), key.size());
199         out.cursor += key.size();
200 
201         /* insert value */
202         std::memcpy(output.data() + out.cursor, value.data(), value.size());
203         out.cursor += value.size();
204     }
205 
206     assert(out.cursor == output.size());
207     assert(!CborOutOverflowed(&out));
208     return output;
209 }
210 
211 /**
212  * encodeBstrHeader() - write CBOR header for a binary string of a given size.
213  * @payloadSize: Size of binary string to encode header for.
214  * @outBufSize:  Size of output buffer.
215  * @outBuf:      Output buffer to write CBOR header to.
216  *
217  * Return:       A pointer to one past the last byte written.
218  */
encodeBstrHeader(uint64_t bstrSize,size_t outBufSize,uint8_t * outBuf)219 static inline uint8_t* encodeBstrHeader(uint64_t bstrSize,
220                                         size_t outBufSize,
221                                         uint8_t* outBuf) {
222     struct CborOut fakeOut;
223     const size_t bstrHeaderSize = cbor::encodedSizeOf(bstrSize);
224     assert(0 < bstrHeaderSize <= outBufSize);
225     size_t fakeBufferSize;
226     if (__builtin_add_overflow(bstrHeaderSize, bstrSize, &fakeBufferSize)) {
227         return nullptr;
228     }
229     // NOTE: CborAllocBstr will fail if we don't provide a buffer object that
230     // appears large enough. CborAllocBstr will *only* write header information
231     // about the binary string so it will only touch allocated memory.
232     CborOutInit(outBuf, fakeBufferSize, &fakeOut);
233     // CborAllocBstr will only write the type and length of the binary string
234     // into outBuf and manipulate the fakeOut object itself. Further
235     // writes to fakeOut will trigger memory corruption.
236     uint8_t* bstrHeaderEnd = CborAllocBstr(bstrSize, &fakeOut);
237     assert(!CborOutOverflowed(&fakeOut));
238     assert(bstrHeaderEnd != nullptr);
239     assert((size_t)(bstrHeaderEnd - outBuf) == bstrHeaderSize);
240 
241     return bstrHeaderEnd;
242 }
243 
/*
 * ArrayVector - minimal vector-like byte buffer backed by a heap array.
 *
 * Provides the subset of the std::vector interface (data(), size(),
 * resize()) that CborEncoder needs. Allocation uses nothrow new, so a failed
 * allocation is reported through size() rather than through an exception.
 */
class ArrayVector {
public:
    /* Pointer to the underlying array, or nullptr if none is owned. */
    uint8_t* data() const { return mArr.get(); }

    /* Number of usable bytes at data(). Never exceeds the allocation. */
    size_t size() const { return mSize; }

    /**
     * resize() - change the reported size of underlying array.
     * @count: New size of the array.
     *
     * This function is needed for compatibility with std::vector. We only
     * support two cases 1) growing a zero-element array and 2) reducing the
     * size of a non-zero element array without shrinking the underlying
     * allocation.
     *
     * If the allocation fails, or an unsupported grow is requested, the
     * reported size is left at the size actually backed by memory. Callers
     * detect the failure by comparing size() with @count.
     */
    void resize(size_t count) {
        if (mSize == 0 && !mArr) {
            mArr.reset(new (std::nothrow) uint8_t[count]);
            mSize = mArr ? /* success */ count : /* fail */ 0;
        } else if (count <= mSize) {
            mSize = count;
        } else {
            /*
             * Shouldn't hit this case since the CountingEncoder computes how
             * many bytes we need for encoding. Do not report a size that
             * the allocation cannot back.
             */
            assert(false && "resizing existing array allocation not supported");
        }
    }

    /*
     * Transfers ownership of the underlying array to the caller. Afterwards
     * data() returns nullptr; size() is not reset.
     */
    std::unique_ptr<uint8_t[]> arr() { return std::move(mArr); }

private:
    std::unique_ptr<uint8_t[]> mArr;
    size_t mSize = 0;
};
284 
285 /**
286  * This class wraps the open-dice API defined in `cbor_writer.h`. Users of this
287  * class need not determine the correct size of the output buffer manually. By
288  * accepting a set of callbacks that are pure (can be called multiple times),
289  * this class first performs a dry run to calculate the number of bytes needed
290  * to represent a structure as CBOR, then it allocates the necessary memory and
291  * performs the actual encoding.
292  *
293  * Limitations:
294  * * The encoder will enter an error state unless map keys are ordered
295  *   canonically as defined in section 3.9 of the CBOR RFC [0].
296  * * The outermost CBOR element must be a tag, map, or an array. Trying to
297  *   encode any other item with a newly created encoder object is not supported.
298  *
299  * [0]: https://datatracker.ietf.org/doc/html/rfc7049#section-3.9
300  */
301 template <typename V>
302 class CborEncoder {
303 private:
304     /**
305      * Helper class which provides the same interface as the CborEncoder but
306      * instead of encoding its arguments, it calculates the encoding length.
307      * This lets us precisely size the CBOR output buffer ahead of time instead
308      * of having to resize it on the fly.
309      *
310      * The counts provided by this class (bytes, array elements, map pairs) are
311      * only valid if the count didn't overflow. Users of this class must check
312      * whether an overflow happened before accepting any other property.
313      */
314     class CountingEncoder {
315     public:
316         template <typename Fn>
encodeTag(int64_t tag,Fn fn)317         void encodeTag(int64_t tag, Fn fn) {
318             CountingEncoder enc;
319             fn(enc);
320 
321             countBytes(encodedSizeOf(tag));
322             countBytes(enc);
323         }
324 
325         template <typename Fn>
encodeArray(Fn fn)326         void encodeArray(Fn fn) {
327             CountingEncoder enc;
328             fn(enc);
329 
330             countBytes(encodedSizeOf(enc.arrayElements()));
331             countBytes(enc);
332         }
333 
334         template <typename Fn>
encodeMap(Fn fn)335         void encodeMap(Fn fn) {
336             CountingEncoder enc;
337             fn(enc);
338 
339             countBytes(encodedSizeOf(enc.mapPairs()));
340             countBytes(enc);
341             mArrayElements++;
342         }
343 
344         template <typename Fn>
encodeKeyValue(int64_t key,Fn fn)345         void encodeKeyValue(int64_t key, Fn fn) {
346             CountingEncoder enc;
347             fn(enc);
348 
349             countBytesToEncode(key);
350             countBytes(enc);
351             mMapPairs++;
352         }
353 
encodeKeyValue(int64_t key,int64_t val)354         void encodeKeyValue(int64_t key, int64_t val) {
355             countBytesToEncode(key);
356             countBytesToEncode(val);
357             mMapPairs++;
358         }
359 
encodeKeyValue(int64_t key,int val)360         void encodeKeyValue(int64_t key, int val) {
361             encodeKeyValue(key, (int64_t)val);
362         }
363 
encodeKeyValue(int64_t key,__UNUSED bool val)364         void encodeKeyValue(int64_t key, __UNUSED bool val) {
365             countBytesToEncode(key);
366             /* Value 20 encodes false; 21 encodes true. Each requires a byte */
367             countBytes(1);
368             mMapPairs++;
369         }
370 
encodeKeyValue(int64_t key,const char * val)371         void encodeKeyValue(int64_t key, const char* val) {
372             size_t len = strlen(val);
373             countBytesToEncode(key);
374             countBytes(encodedSizeOf(len));
375             countBytes(len);
376             mMapPairs++;
377         }
378 
encodeTstr(const std::basic_string_view<char> str)379         void encodeTstr(const std::basic_string_view<char> str) {
380             size_t len = str.size();
381             countBytes(encodedSizeOf(len));
382             countBytes(len);
383             mArrayElements++;
384         }
385 
encodeTstr(const char * str)386         void encodeTstr(const char* str) {
387             size_t len = strlen(str);
388             countBytes(encodedSizeOf(len));
389             countBytes(len);
390             mArrayElements++;
391         }
392 
encodeBstr(const std::string & str)393         void encodeBstr(const std::string& str) {
394             encodeBstr(reinterpret_cast<const uint8_t*>(str.data()),
395                        str.size());
396         }
397 
encodeBstr(const std::vector<uint8_t> & vec)398         void encodeBstr(const std::vector<uint8_t>& vec) {
399             encodeBstr(vec.data(), vec.size());
400         }
401 
encodeBstr(const std::span<const uint8_t> & view)402         void encodeBstr(const std::span<const uint8_t>& view) {
403             encodeBstr(view.data(), view.size());
404         }
405 
encodeBstr(__UNUSED const uint8_t * src,const size_t srcsz)406         void encodeBstr(__UNUSED const uint8_t* src, const size_t srcsz) {
407             countBytes(encodedSizeOf(srcsz));
408             countBytes(srcsz);
409             mArrayElements++;
410         }
411 
encodeEmptyBstr()412         void encodeEmptyBstr() {
413             countBytes(1); /* null is encoded as value 22 and takes up a byte */
414             mArrayElements++;
415         }
416 
encodeInt(const int64_t val)417         void encodeInt(const int64_t val) {
418             countBytesToEncode(val);
419             mArrayElements++;
420         }
421 
encodeUint(const uint64_t val)422         void encodeUint(const uint64_t val) {
423             countBytes(encodedSizeOf(val));
424             mArrayElements++;
425         }
426 
encodeNull()427         void encodeNull() {
428             countBytes(1);
429             mArrayElements++;
430         }
431 
copyBytes(const std::span<const uint8_t> & view)432         bool copyBytes(const std::span<const uint8_t>& view) {
433             return copyBytes(view.data(), view.size());
434         }
435 
copyBytes(const std::vector<uint8_t> & vec)436         bool copyBytes(const std::vector<uint8_t>& vec) {
437             return copyBytes(vec.data(), vec.size());
438         }
439 
copyBytes(const uint8_t * src,const size_t srcsz)440         bool copyBytes(const uint8_t* src, const size_t srcsz) {
441             countBytes(srcsz);
442             mArrayElements++;
443             return !mOverflowed;
444         }
445 
overflowed()446         bool overflowed() const { return mOverflowed; }
447 
bytes()448         size_t bytes() const { return mBytes; }
449 
arrayElements()450         size_t arrayElements() const { return mArrayElements; }
451 
mapPairs()452         size_t mapPairs() const { return mMapPairs; }
453 
454     private:
455         /*
456          * if true, the count failed and other properties should not be relied
457          * upon for CBOR encoding.
458          */
459         bool mOverflowed = false;
460         /* bytes needed for CBOR encoding unless an overflow occurred */
461         size_t mBytes = 0;
462         /* array elements, not including sub-elements, to write */
463         size_t mArrayElements = 0;
464         /* map pairs, not including map pairs in sub-elements, to write */
465         size_t mMapPairs = 0;
466 
countBytes(size_t count)467         void countBytes(size_t count) {
468             mOverflowed |= __builtin_add_overflow(count, mBytes, &mBytes);
469         }
470 
countBytes(CountingEncoder enc)471         void countBytes(CountingEncoder enc) {
472             if (enc.overflowed()) {
473                 mOverflowed = true;
474                 return;
475             }
476             countBytes(enc.bytes());
477         }
478 
countBytesToEncode(int64_t val)479         void countBytesToEncode(int64_t val) {
480             countBytes(encodedSizeOfInt(val));
481         }
482 
483         /* CborEncoded calls countBytes */
484         friend class CborEncoder;
485     };
486 
487 public:
488     enum class State {
489         /* buffer not allocated */
490         kInitial,
491         /* initialization or resizing of buffer failed */
492         kInvalid,
493         /* encoding or ready to encode */
494         kEncoding,
495         /* encoding would have overflowed buffer */
496         kOverflowed,
497         /* encoder no longer owns buffer */
498         kEmptied,
499     };
500 
501     template <typename Fn>
encodeTag(int64_t tag,Fn fn)502     void encodeTag(int64_t tag, Fn fn) {
503         CountingEncoder enc;
504         fn(enc);
505         enc.countBytes(encodedSizeOf(tag));
506 
507         if (enc.overflowed()) {
508             mState = State::kOverflowed;
509             return;
510         }
511 
512         if (ensureCapacity(enc.bytes())) {
513             CborWriteTag(tag, &mOut);
514             fn(*this);
515         }
516     }
517 
518     template <typename Fn>
encodeArray(Fn fn)519     void encodeArray(Fn fn) {
520         CountingEncoder enc;
521         fn(enc);
522         enc.countBytes(encodedSizeOf(enc.arrayElements()));
523 
524         if (enc.overflowed()) {
525             mState = State::kOverflowed;
526             return;
527         }
528 
529         if (ensureCapacity(enc.bytes())) {
530             CborWriteArray(enc.arrayElements(), &mOut);
531             fn(*this);
532         }
533     }
534 
535     template <typename Fn>
encodeMap(Fn fn)536     void encodeMap(Fn fn) {
537         CountingEncoder enc;
538         fn(enc);
539         enc.countBytes(encodedSizeOf(enc.mapPairs()));
540 
541         if (enc.overflowed()) {
542             mState = State::kOverflowed;
543             return;
544         }
545 
546         if (ensureCapacity(enc.bytes())) {
547             CborWriteMap(enc.mapPairs(), &mOut);
548 
549             auto savedKey = mLastKey;
550             mLastKey = std::span<uint8_t>{};
551 
552             fn(*this);
553 
554             mLastKey = savedKey;
555         }
556     }
557 
558     template <typename Fn>
encodeKeyValue(int64_t key,Fn fn)559     void encodeKeyValue(int64_t key, Fn fn) {
560         encodeKeyCanonicalOrder([key, this] { CborWriteInt(key, &mOut); });
561 
562         fn(*this);
563     }
564 
encodeKeyValue(int64_t key,int val)565     void encodeKeyValue(int64_t key, int val) {
566         encodeKeyValue(key, (int64_t)val);
567     }
568 
encodeKeyValue(int64_t key,int64_t val)569     void encodeKeyValue(int64_t key, int64_t val) {
570         encodeKeyCanonicalOrder([key, this] { CborWriteInt(key, &mOut); });
571         CborWriteInt(val, &mOut);
572     }
573 
encodeKeyValue(int64_t key,bool val)574     void encodeKeyValue(int64_t key, bool val) {
575         encodeKeyCanonicalOrder([key, this] { CborWriteInt(key, &mOut); });
576         if (val)
577             CborWriteTrue(&mOut);
578         else
579             CborWriteFalse(&mOut);
580     }
581 
encodeKeyValue(int64_t key,const char * val)582     void encodeKeyValue(int64_t key, const char* val) {
583         encodeKeyCanonicalOrder([key, this] { CborWriteInt(key, &mOut); });
584         encodeTstr(val);
585     }
586 
encodeTstr(const char * str)587     void encodeTstr(const char* str) {
588         const std::string_view view(str);
589         encodeTstr(view);
590     }
591 
encodeTstr(const std::string_view str)592     void encodeTstr(const std::string_view str) {
593         ensureEncoding();
594         CborWriteTstr(str.data(), &mOut);
595     }
596 
encodeBstr(const std::string & str)597     void encodeBstr(const std::string& str) {
598         encodeBstr(reinterpret_cast<const uint8_t*>(str.data()), str.size());
599     }
600 
encodeBstr(const std::span<const uint8_t> & byteView)601     void encodeBstr(const std::span<const uint8_t>& byteView) {
602         encodeBstr(byteView.data(), byteView.size());
603     }
604 
encodeBstr(const std::vector<uint8_t> & vec)605     void encodeBstr(const std::vector<uint8_t>& vec) {
606         encodeBstr(vec.data(), vec.size());
607     }
608 
encodeBstr(const uint8_t * data,const size_t size)609     void encodeBstr(const uint8_t* data, const size_t size) {
610         ensureEncoding();
611         CborWriteBstr(size, data, &mOut);
612     }
613 
encodeEmptyBstr()614     void encodeEmptyBstr() {
615         ensureEncoding();
616         encodeBstr(nullptr, 0);
617     }
618 
encodeInt(const int64_t val)619     void encodeInt(const int64_t val) {
620         ensureEncoding();
621         CborWriteInt(val, &mOut);
622     }
623 
encodeUint(const uint64_t val)624     void encodeUint(const uint64_t val) {
625         ensureEncoding();
626         CborWriteUint(val, &mOut);
627     }
628 
encodeNull()629     void encodeNull() {
630         ensureEncoding();
631         CborWriteNull(&mOut);
632     }
633 
copyBytes(const std::span<const uint8_t> & view)634     bool copyBytes(const std::span<const uint8_t>& view) {
635         return copyBytes(view.data(), view.size());
636     }
637 
copyBytes(const std::vector<uint8_t> & vec)638     bool copyBytes(const std::vector<uint8_t>& vec) {
639         return copyBytes(vec.data(), vec.size());
640     }
641 
copyBytes(const uint8_t * src,const size_t srcsz)642     bool copyBytes(const uint8_t* src, const size_t srcsz) {
643         if (CborOutOverflowed(&mOut) || mState == State::kOverflowed) {
644             goto err_overflow;
645         }
646 
647         if (mState != State::kEncoding) {
648             mState = State::kInvalid;
649             return false;
650         }
651 
652         size_t dest;
653         if (__builtin_add_overflow((size_t)mOut.buffer, mOut.cursor, &dest)) {
654             goto err_overflow;
655         }
656 
657         size_t destsz;
658         if (__builtin_sub_overflow(mBuffer.size(), mOut.cursor, &destsz)) {
659             goto err_overflow;
660         }
661 
662         if (destsz < srcsz) {
663             goto err_overflow;
664         }
665 
666         std::memcpy((void*)dest, src, srcsz);
667         mOut.cursor += srcsz;
668         return true;
669 
670     err_overflow:
671         mState = State::kOverflowed;
672         return false;
673     }
674 
intoVec()675     V intoVec() {
676         assert(mState != State::kEmptied && "buffer was moved out of encoder");
677         assert(mState != State::kInvalid && "encoder is in an invalid state");
678         if (mState != State::kInitial) {
679             assert((!CborOutOverflowed(&mOut) &&
680                     mState != State::kOverflowed) &&
681                    "buffer was too small to hold cbor encoded content");
682             assert(mBuffer.size() == CborOutSize(&mOut) &&
683                    "buffer was larger than required to hold encoded content");
684         }
685         mState = State::kEmptied;
686         return std::move(mBuffer);
687     }
688 
view()689     std::span<const uint8_t> view() const {
690         assert((mState == State::kInitial || mState == State::kEncoding) &&
691                "requested view of buffer from encoder in invalid state");
692         if (mState != State::kInitial) {
693             assert((!CborOutOverflowed(&mOut) &&
694                     mState != State::kOverflowed) &&
695                    "buffer was too small to hold CBOR encoded content");
696             assert(mBuffer.size() == CborOutSize(&mOut) &&
697                    "buffer was larger than required to hold CBOR encoded content");
698         }
699         return mBuffer;
700     }
701 
size()702     size_t size() const {
703         if (mState != State::kInitial) {
704             assert((!CborOutOverflowed(&mOut) &&
705                     mState != State::kOverflowed) &&
706                    "requested encoding size after overflow");
707 
708             return CborOutSize(&mOut);
709         } else {
710             return 0u;
711         }
712     }
713 
state()714     State state() const { return mState; }
715 
716 private:
717     State mState = State::kInitial;
718     /* vector or vector-like buffer */
719     V mBuffer;
720     /* cursor used for CBOR encoding which points into mBuffer */
721     struct CborOut mOut;
722     /*
723      * Used to ensure that map keys are encoded in canonical order. When the
724      * encoder is not encoding a map, the field has no value.
725      */
726     std::optional<std::span<const uint8_t>> mLastKey = std::nullopt;
727     /* determines CBOR ordering between two keys */
728     CBORCompare<uint8_t> mComparer;
729 
ensureEncoding()730     void ensureEncoding() const {
731         assert(mState == State::kEncoding &&
732                "Call encodeArray, encodeTag, or encodeMap before this method");
733     }
734 
ensureCapacity(const size_t capacity)735     bool ensureCapacity(const size_t capacity) {
736         if (mState == State::kInitial) {
737             mBuffer.resize(capacity);
738             CborOutInit(mBuffer.data(), mBuffer.size(), &mOut);
739             mState = mBuffer.size() == capacity ? State::kEncoding
740                                                 : State::kInvalid;
741         }
742         return mState == State::kEncoding;
743     }
744 
745     template <typename Fn>
encodeKeyCanonicalOrder(Fn fn)746     void encodeKeyCanonicalOrder(Fn fn) {
747         if (mState != State::kEncoding || !mLastKey.has_value()) {
748             mState = State::kInvalid;
749             return;
750         }
751 
752         const struct CborOut preCursor = mOut;
753         fn();
754 
755         const size_t newKeySz = mOut.cursor - preCursor.cursor;
756         const uint8_t* newKeyStart = preCursor.buffer + preCursor.cursor;
757         const std::span newKey(newKeyStart, newKeySz);
758 
759         /*
760          * The keys in every map must be sorted lowest value to highest.
761          * Sorting is performed on the bytes of the representation of the key
762          * data items without paying attention to the 3/5 bit splitting for
763          * major types. The sorting rules are:
764          *
765          *  * If two keys have different lengths, the shorter one sorts earlier;
766          *
767          *  * If two keys have the same length, the one with the lower value
768          *    in (byte-wise) lexical order sorts earlier.
769          */
770         if (mComparer.keyLess(mLastKey.value(), newKey)) {
771             mLastKey = newKey;
772             return;
773         }
774 
775         /* CBOR encoding is not canonical */
776         mState = State::kInvalid;
777     }
778 };
779 
780 using ArrayCborEncoder = CborEncoder<ArrayVector>;
781 using VectorCborEncoder = CborEncoder<std::vector<uint8_t>>;
782 
783 }  // namespace cbor
784