1 /*
2 * Copyright (C) 2022 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #pragma once
18
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <map>
#include <memory>
#include <new>
#include <optional>
#include <span>
#include <string>
#include <string_view>
#include <vector>
27
28 #include <lk/compiler.h>
29
30 #include <dice/cbor_reader.h>
31 #include <dice/cbor_writer.h>
32
33 namespace cbor {
34
35 /**
36 * readCborBoolean() - Read boolean value from CBOR input object.
37 * @in: Initialized CBOR input object to read from
38 *
39 * Return: boolean if read succeeds, %nullopt otherwise
40 */
readCborBoolean(struct CborIn & in)41 static inline std::optional<bool> readCborBoolean(struct CborIn& in) {
42 if (CborReadTrue(&in) == CBOR_READ_RESULT_OK) {
43 return true;
44 } else if (CborReadFalse(&in) == CBOR_READ_RESULT_OK) {
45 return false;
46 } else {
47 return std::nullopt;
48 }
49 }
50
51 /**
52 * encodedSizeOf() - Get number of bytes required to encode CBOR item.
53 * @val: Value (for types with no content) or length of CBOR item.
54 *
55 * Figure how many bytes we need to encode a CBOR item with a particular value
56 * or item count. This function is not limited to determining the size of
57 * unsigned integers since the encoding of arrays, maps, and scalars use a
58 * similar encoding. For CBOR types that have content, the result does not
59 * include the bytes required to store the content itself.
60 *
61 * Return: Encoded size of CBOR value less the size of its content if any.
62 */
encodedSizeOf(uint64_t val)63 static inline size_t encodedSizeOf(uint64_t val) {
64 uint8_t buffer[16];
65 struct CborOut out;
66 CborOutInit(buffer, sizeof(buffer), &out);
67 CborWriteUint(val, &out);
68 assert(!CborOutOverflowed(&out));
69 return CborOutSize(&out);
70 }
71
72 /**
73 * encodedSizeOfInt() - Get number of bytes required to encode signed integer.
74 * @val: Integer to encode.
75 *
76 * Return: Encoded size of signed integer.
77 */
encodedSizeOfInt(int64_t val)78 static inline size_t encodedSizeOfInt(int64_t val) {
79 uint8_t buffer[16];
80 struct CborOut out;
81 CborOutInit(buffer, sizeof(buffer), &out);
82 CborWriteInt(val, &out);
83 assert(!CborOutOverflowed(&out));
84 return CborOutSize(&out);
85 }
86
87 /**
88 * Sorts the map in canonical order, as defined in RFC 7049.
89 * https://datatracker.ietf.org/doc/html/rfc7049#section-3.9
90 */
91 template <typename T>
92 struct CBORCompare {
operatorCBORCompare93 constexpr bool operator()(const std::vector<T>& a,
94 const std::vector<T>& b) const {
95 return keyLess(a, b);
96 }
97
98 /* Returns true iff key a sorts before key b in CBOR order */
keyLessCBORCompare99 constexpr bool keyLess(const std::span<const T>& a,
100 const std::span<const T>& b) const {
101 /* If two keys have different lengths, the shorter one sorts earlier */
102 if (a.size() < b.size())
103 return true;
104 if (a.size() > b.size())
105 return false;
106
107 /* If keys have the same length, do a byte-wise comparison */
108 return std::lexicographical_compare(a.begin(), a.end(), b.begin(),
109 b.end());
110 }
111 };
112
113 using CborMap = std::map<std::vector<uint8_t>,
114 std::span<const uint8_t>,
115 CBORCompare<uint8_t>>;
116
populateMap(CborMap & map,const std::span<const uint8_t> & encMap)117 static inline bool populateMap(CborMap& map,
118 const std::span<const uint8_t>& encMap) {
119 if (!encMap.size()) {
120 /* No elements to add to map */
121 return true;
122 }
123
124 struct CborIn in;
125 CborInInit(encMap.data(), encMap.size(), &in);
126
127 size_t numPairs;
128 if (CborReadMap(&in, &numPairs) != CBOR_READ_RESULT_OK) {
129 return false;
130 }
131
132 int64_t key;
133 struct CborOut out;
134 struct CborIn savedIn;
135 for (size_t i = 0; i < numPairs; i++) {
136 /* Read key */
137 if (CborReadInt(&in, &key) != CBOR_READ_RESULT_OK) {
138 return false;
139 }
140
141 /* skip value */
142 savedIn = in;
143 if (CborReadSkip(&in) != CBOR_READ_RESULT_OK) {
144 return false;
145 }
146
147 std::vector<uint8_t> encKey(encodedSizeOfInt(key));
148 CborOutInit(encKey.data(), encKey.size(), &out);
149 CborWriteInt(key, &out);
150 assert(!CborOutOverflowed(&out));
151
152 std::span value(savedIn.buffer + savedIn.cursor,
153 in.cursor - savedIn.cursor);
154
155 map[std::move(encKey)] = value;
156 }
157
158 return true;
159 }
160
161 /**
162 * mergeMaps() - Merge the items in two CBOR maps, return canonical map.
163 *
164 * @lhs: CBOR-encoded map using signed integers as keys
165 * @rhs: CBOR-encoded map using signed integers as keys
166 *
167 * Return:
168 * Canonical CBOR encoding of the combined map or %nullopt if an error
169 * occurred.
170 */
mergeMaps(const std::span<const uint8_t> & lhs,const std::span<const uint8_t> & rhs)171 static inline std::optional<std::vector<uint8_t>> mergeMaps(
172 const std::span<const uint8_t>& lhs,
173 const std::span<const uint8_t>& rhs) {
174 /*
175 * map is sorted on the encoded key which ensures that the CBOR encoding is
176 * canonical.
177 */
178 CborMap map;
179
180 if (!populateMap(map, lhs)) {
181 return std::nullopt;
182 }
183 if (!populateMap(map, rhs)) {
184 return std::nullopt;
185 }
186
187 size_t outputSize = encodedSizeOf(map.size());
188 for (const auto& [key, value] : map) {
189 outputSize += key.size() + value.size();
190 }
191
192 auto output = std::vector<uint8_t>(outputSize);
193 struct CborOut out;
194 CborOutInit(output.data(), output.size(), &out);
195 CborWriteMap(map.size(), &out);
196 for (const auto& [key, value] : map) {
197 /* insert key */
198 std::memcpy(output.data() + out.cursor, key.data(), key.size());
199 out.cursor += key.size();
200
201 /* insert value */
202 std::memcpy(output.data() + out.cursor, value.data(), value.size());
203 out.cursor += value.size();
204 }
205
206 assert(out.cursor == output.size());
207 assert(!CborOutOverflowed(&out));
208 return output;
209 }
210
211 /**
212 * encodeBstrHeader() - write CBOR header for a binary string of a given size.
213 * @payloadSize: Size of binary string to encode header for.
214 * @outBufSize: Size of output buffer.
215 * @outBuf: Output buffer to write CBOR header to.
216 *
217 * Return: A pointer to one past the last byte written.
218 */
encodeBstrHeader(uint64_t bstrSize,size_t outBufSize,uint8_t * outBuf)219 static inline uint8_t* encodeBstrHeader(uint64_t bstrSize,
220 size_t outBufSize,
221 uint8_t* outBuf) {
222 struct CborOut fakeOut;
223 const size_t bstrHeaderSize = cbor::encodedSizeOf(bstrSize);
224 assert(0 < bstrHeaderSize <= outBufSize);
225 size_t fakeBufferSize;
226 if (__builtin_add_overflow(bstrHeaderSize, bstrSize, &fakeBufferSize)) {
227 return nullptr;
228 }
229 // NOTE: CborAllocBstr will fail if we don't provide a buffer object that
230 // appears large enough. CborAllocBstr will *only* write header information
231 // about the binary string so it will only touch allocated memory.
232 CborOutInit(outBuf, fakeBufferSize, &fakeOut);
233 // CborAllocBstr will only write the type and length of the binary string
234 // into outBuf and manipulate the fakeOut object itself. Further
235 // writes to fakeOut will trigger memory corruption.
236 uint8_t* bstrHeaderEnd = CborAllocBstr(bstrSize, &fakeOut);
237 assert(!CborOutOverflowed(&fakeOut));
238 assert(bstrHeaderEnd != nullptr);
239 assert((size_t)(bstrHeaderEnd - outBuf) == bstrHeaderSize);
240
241 return bstrHeaderEnd;
242 }
243
/**
 * ArrayVector - minimal non-throwing stand-in for std::vector<uint8_t>.
 *
 * Offers the data()/size()/resize() subset CborEncoder needs, backed by a
 * heap array allocated with nothrow new. Ownership of the array can be taken
 * with arr().
 */
class ArrayVector {
public:
    /* Pointer to the owned array, or nullptr if none is owned. */
    uint8_t* data() const { return mArr.get(); }

    /* Number of bytes currently reported as usable. */
    size_t size() const { return mSize; }

    /**
     * resize() - change the reported size of underlying array.
     * @count: New size of the array.
     *
     * This function is needed for compatibility with std::vector. We only
     * support two cases 1) growing a zero-element array and 2) reducing the
     * size of a non-zero element array without shrinking the underlying
     * allocation.
     *
     * If the allocation in case 1 fails, size() stays 0, so a caller can
     * detect the failure by comparing size() against @count. The unsupported
     * case (growing an existing allocation) leaves the size unchanged so that
     * size() never exceeds the real allocation.
     */
    void resize(size_t count) {
        if (mSize == 0 && !mArr) {
            mArr = std::unique_ptr<uint8_t[]>(new (std::nothrow)
                                                      uint8_t[count]);
            /* Only report the new size if the allocation succeeded. */
            mSize = mArr ? count : 0;
        } else if (count <= mSize) {
            /* Shrink the reported size; the allocation itself is kept. */
            mSize = count;
        } else {
            /*
             * Shouldn't hit this case since the CountingEncoder computes how
             * many bytes we need for encoding.
             */
            assert(false && "resizing existing array allocation not supported");
        }
    }

    /*
     * Transfers ownership of the array to the caller; data() returns nullptr
     * afterwards. size() is not reset by this call.
     */
    std::unique_ptr<uint8_t[]> arr() { return std::move(mArr); }

private:
    std::unique_ptr<uint8_t[]> mArr;
    size_t mSize = 0;
};
284
285 /**
286 * This class wraps the open-dice API defined in `cbor_writer.h`. Users of this
287 * class need not determine the correct size of the output buffer manually. By
288 * accepting a set of callbacks that are pure (can be called multiple times),
289 * this class first performs a dry run to calculate the number of bytes needed
290 * to represent a structure as CBOR, then it allocates the necessary memory and
291 * performs the actual encoding.
292 *
293 * Limitations:
294 * * The encoder will enter an error state unless map keys are ordered
295 * canonically as defined in section 3.9 of the CBOR RFC [0].
296 * * The outermost CBOR element must be a tag, map, or an array. Trying to
297 * encode any other item with a newly created encoder object is not supported.
298 *
299 * [0]: https://datatracker.ietf.org/doc/html/rfc7049#section-3.9
300 */
301 template <typename V>
302 class CborEncoder {
303 private:
304 /**
305 * Helper class which provides the same interface as the CborEncoder but
306 * instead of encoding its arguments, it calculates the encoding length.
307 * This lets us precisely size the CBOR output buffer ahead of time instead
308 * of having to resize it on the fly.
309 *
310 * The counts provided by this class (bytes, array elements, map pairs) are
311 * only valid if the count didn't overflow. Users of this class must check
312 * whether an overflow happened before accepting any other property.
313 */
314 class CountingEncoder {
315 public:
316 template <typename Fn>
encodeTag(int64_t tag,Fn fn)317 void encodeTag(int64_t tag, Fn fn) {
318 CountingEncoder enc;
319 fn(enc);
320
321 countBytes(encodedSizeOf(tag));
322 countBytes(enc);
323 }
324
325 template <typename Fn>
encodeArray(Fn fn)326 void encodeArray(Fn fn) {
327 CountingEncoder enc;
328 fn(enc);
329
330 countBytes(encodedSizeOf(enc.arrayElements()));
331 countBytes(enc);
332 }
333
334 template <typename Fn>
encodeMap(Fn fn)335 void encodeMap(Fn fn) {
336 CountingEncoder enc;
337 fn(enc);
338
339 countBytes(encodedSizeOf(enc.mapPairs()));
340 countBytes(enc);
341 mArrayElements++;
342 }
343
344 template <typename Fn>
encodeKeyValue(int64_t key,Fn fn)345 void encodeKeyValue(int64_t key, Fn fn) {
346 CountingEncoder enc;
347 fn(enc);
348
349 countBytesToEncode(key);
350 countBytes(enc);
351 mMapPairs++;
352 }
353
encodeKeyValue(int64_t key,int64_t val)354 void encodeKeyValue(int64_t key, int64_t val) {
355 countBytesToEncode(key);
356 countBytesToEncode(val);
357 mMapPairs++;
358 }
359
encodeKeyValue(int64_t key,int val)360 void encodeKeyValue(int64_t key, int val) {
361 encodeKeyValue(key, (int64_t)val);
362 }
363
encodeKeyValue(int64_t key,__UNUSED bool val)364 void encodeKeyValue(int64_t key, __UNUSED bool val) {
365 countBytesToEncode(key);
366 /* Value 20 encodes false; 21 encodes true. Each requires a byte */
367 countBytes(1);
368 mMapPairs++;
369 }
370
encodeKeyValue(int64_t key,const char * val)371 void encodeKeyValue(int64_t key, const char* val) {
372 size_t len = strlen(val);
373 countBytesToEncode(key);
374 countBytes(encodedSizeOf(len));
375 countBytes(len);
376 mMapPairs++;
377 }
378
encodeTstr(const std::basic_string_view<char> str)379 void encodeTstr(const std::basic_string_view<char> str) {
380 size_t len = str.size();
381 countBytes(encodedSizeOf(len));
382 countBytes(len);
383 mArrayElements++;
384 }
385
encodeTstr(const char * str)386 void encodeTstr(const char* str) {
387 size_t len = strlen(str);
388 countBytes(encodedSizeOf(len));
389 countBytes(len);
390 mArrayElements++;
391 }
392
encodeBstr(const std::string & str)393 void encodeBstr(const std::string& str) {
394 encodeBstr(reinterpret_cast<const uint8_t*>(str.data()),
395 str.size());
396 }
397
encodeBstr(const std::vector<uint8_t> & vec)398 void encodeBstr(const std::vector<uint8_t>& vec) {
399 encodeBstr(vec.data(), vec.size());
400 }
401
encodeBstr(const std::span<const uint8_t> & view)402 void encodeBstr(const std::span<const uint8_t>& view) {
403 encodeBstr(view.data(), view.size());
404 }
405
encodeBstr(__UNUSED const uint8_t * src,const size_t srcsz)406 void encodeBstr(__UNUSED const uint8_t* src, const size_t srcsz) {
407 countBytes(encodedSizeOf(srcsz));
408 countBytes(srcsz);
409 mArrayElements++;
410 }
411
encodeEmptyBstr()412 void encodeEmptyBstr() {
413 countBytes(1); /* null is encoded as value 22 and takes up a byte */
414 mArrayElements++;
415 }
416
encodeInt(const int64_t val)417 void encodeInt(const int64_t val) {
418 countBytesToEncode(val);
419 mArrayElements++;
420 }
421
encodeUint(const uint64_t val)422 void encodeUint(const uint64_t val) {
423 countBytes(encodedSizeOf(val));
424 mArrayElements++;
425 }
426
encodeNull()427 void encodeNull() {
428 countBytes(1);
429 mArrayElements++;
430 }
431
copyBytes(const std::span<const uint8_t> & view)432 bool copyBytes(const std::span<const uint8_t>& view) {
433 return copyBytes(view.data(), view.size());
434 }
435
copyBytes(const std::vector<uint8_t> & vec)436 bool copyBytes(const std::vector<uint8_t>& vec) {
437 return copyBytes(vec.data(), vec.size());
438 }
439
copyBytes(const uint8_t * src,const size_t srcsz)440 bool copyBytes(const uint8_t* src, const size_t srcsz) {
441 countBytes(srcsz);
442 mArrayElements++;
443 return !mOverflowed;
444 }
445
overflowed()446 bool overflowed() const { return mOverflowed; }
447
bytes()448 size_t bytes() const { return mBytes; }
449
arrayElements()450 size_t arrayElements() const { return mArrayElements; }
451
mapPairs()452 size_t mapPairs() const { return mMapPairs; }
453
454 private:
455 /*
456 * if true, the count failed and other properties should not be relied
457 * upon for CBOR encoding.
458 */
459 bool mOverflowed = false;
460 /* bytes needed for CBOR encoding unless an overflow occurred */
461 size_t mBytes = 0;
462 /* array elements, not including sub-elements, to write */
463 size_t mArrayElements = 0;
464 /* map pairs, not including map pairs in sub-elements, to write */
465 size_t mMapPairs = 0;
466
countBytes(size_t count)467 void countBytes(size_t count) {
468 mOverflowed |= __builtin_add_overflow(count, mBytes, &mBytes);
469 }
470
countBytes(CountingEncoder enc)471 void countBytes(CountingEncoder enc) {
472 if (enc.overflowed()) {
473 mOverflowed = true;
474 return;
475 }
476 countBytes(enc.bytes());
477 }
478
countBytesToEncode(int64_t val)479 void countBytesToEncode(int64_t val) {
480 countBytes(encodedSizeOfInt(val));
481 }
482
483 /* CborEncoded calls countBytes */
484 friend class CborEncoder;
485 };
486
487 public:
488 enum class State {
489 /* buffer not allocated */
490 kInitial,
491 /* initialization or resizing of buffer failed */
492 kInvalid,
493 /* encoding or ready to encode */
494 kEncoding,
495 /* encoding would have overflowed buffer */
496 kOverflowed,
497 /* encoder no longer owns buffer */
498 kEmptied,
499 };
500
501 template <typename Fn>
encodeTag(int64_t tag,Fn fn)502 void encodeTag(int64_t tag, Fn fn) {
503 CountingEncoder enc;
504 fn(enc);
505 enc.countBytes(encodedSizeOf(tag));
506
507 if (enc.overflowed()) {
508 mState = State::kOverflowed;
509 return;
510 }
511
512 if (ensureCapacity(enc.bytes())) {
513 CborWriteTag(tag, &mOut);
514 fn(*this);
515 }
516 }
517
518 template <typename Fn>
encodeArray(Fn fn)519 void encodeArray(Fn fn) {
520 CountingEncoder enc;
521 fn(enc);
522 enc.countBytes(encodedSizeOf(enc.arrayElements()));
523
524 if (enc.overflowed()) {
525 mState = State::kOverflowed;
526 return;
527 }
528
529 if (ensureCapacity(enc.bytes())) {
530 CborWriteArray(enc.arrayElements(), &mOut);
531 fn(*this);
532 }
533 }
534
535 template <typename Fn>
encodeMap(Fn fn)536 void encodeMap(Fn fn) {
537 CountingEncoder enc;
538 fn(enc);
539 enc.countBytes(encodedSizeOf(enc.mapPairs()));
540
541 if (enc.overflowed()) {
542 mState = State::kOverflowed;
543 return;
544 }
545
546 if (ensureCapacity(enc.bytes())) {
547 CborWriteMap(enc.mapPairs(), &mOut);
548
549 auto savedKey = mLastKey;
550 mLastKey = std::span<uint8_t>{};
551
552 fn(*this);
553
554 mLastKey = savedKey;
555 }
556 }
557
558 template <typename Fn>
encodeKeyValue(int64_t key,Fn fn)559 void encodeKeyValue(int64_t key, Fn fn) {
560 encodeKeyCanonicalOrder([key, this] { CborWriteInt(key, &mOut); });
561
562 fn(*this);
563 }
564
encodeKeyValue(int64_t key,int val)565 void encodeKeyValue(int64_t key, int val) {
566 encodeKeyValue(key, (int64_t)val);
567 }
568
encodeKeyValue(int64_t key,int64_t val)569 void encodeKeyValue(int64_t key, int64_t val) {
570 encodeKeyCanonicalOrder([key, this] { CborWriteInt(key, &mOut); });
571 CborWriteInt(val, &mOut);
572 }
573
encodeKeyValue(int64_t key,bool val)574 void encodeKeyValue(int64_t key, bool val) {
575 encodeKeyCanonicalOrder([key, this] { CborWriteInt(key, &mOut); });
576 if (val)
577 CborWriteTrue(&mOut);
578 else
579 CborWriteFalse(&mOut);
580 }
581
encodeKeyValue(int64_t key,const char * val)582 void encodeKeyValue(int64_t key, const char* val) {
583 encodeKeyCanonicalOrder([key, this] { CborWriteInt(key, &mOut); });
584 encodeTstr(val);
585 }
586
encodeTstr(const char * str)587 void encodeTstr(const char* str) {
588 const std::string_view view(str);
589 encodeTstr(view);
590 }
591
encodeTstr(const std::string_view str)592 void encodeTstr(const std::string_view str) {
593 ensureEncoding();
594 CborWriteTstr(str.data(), &mOut);
595 }
596
encodeBstr(const std::string & str)597 void encodeBstr(const std::string& str) {
598 encodeBstr(reinterpret_cast<const uint8_t*>(str.data()), str.size());
599 }
600
encodeBstr(const std::span<const uint8_t> & byteView)601 void encodeBstr(const std::span<const uint8_t>& byteView) {
602 encodeBstr(byteView.data(), byteView.size());
603 }
604
encodeBstr(const std::vector<uint8_t> & vec)605 void encodeBstr(const std::vector<uint8_t>& vec) {
606 encodeBstr(vec.data(), vec.size());
607 }
608
encodeBstr(const uint8_t * data,const size_t size)609 void encodeBstr(const uint8_t* data, const size_t size) {
610 ensureEncoding();
611 CborWriteBstr(size, data, &mOut);
612 }
613
encodeEmptyBstr()614 void encodeEmptyBstr() {
615 ensureEncoding();
616 encodeBstr(nullptr, 0);
617 }
618
encodeInt(const int64_t val)619 void encodeInt(const int64_t val) {
620 ensureEncoding();
621 CborWriteInt(val, &mOut);
622 }
623
encodeUint(const uint64_t val)624 void encodeUint(const uint64_t val) {
625 ensureEncoding();
626 CborWriteUint(val, &mOut);
627 }
628
encodeNull()629 void encodeNull() {
630 ensureEncoding();
631 CborWriteNull(&mOut);
632 }
633
copyBytes(const std::span<const uint8_t> & view)634 bool copyBytes(const std::span<const uint8_t>& view) {
635 return copyBytes(view.data(), view.size());
636 }
637
copyBytes(const std::vector<uint8_t> & vec)638 bool copyBytes(const std::vector<uint8_t>& vec) {
639 return copyBytes(vec.data(), vec.size());
640 }
641
copyBytes(const uint8_t * src,const size_t srcsz)642 bool copyBytes(const uint8_t* src, const size_t srcsz) {
643 if (CborOutOverflowed(&mOut) || mState == State::kOverflowed) {
644 goto err_overflow;
645 }
646
647 if (mState != State::kEncoding) {
648 mState = State::kInvalid;
649 return false;
650 }
651
652 size_t dest;
653 if (__builtin_add_overflow((size_t)mOut.buffer, mOut.cursor, &dest)) {
654 goto err_overflow;
655 }
656
657 size_t destsz;
658 if (__builtin_sub_overflow(mBuffer.size(), mOut.cursor, &destsz)) {
659 goto err_overflow;
660 }
661
662 if (destsz < srcsz) {
663 goto err_overflow;
664 }
665
666 std::memcpy((void*)dest, src, srcsz);
667 mOut.cursor += srcsz;
668 return true;
669
670 err_overflow:
671 mState = State::kOverflowed;
672 return false;
673 }
674
intoVec()675 V intoVec() {
676 assert(mState != State::kEmptied && "buffer was moved out of encoder");
677 assert(mState != State::kInvalid && "encoder is in an invalid state");
678 if (mState != State::kInitial) {
679 assert((!CborOutOverflowed(&mOut) &&
680 mState != State::kOverflowed) &&
681 "buffer was too small to hold cbor encoded content");
682 assert(mBuffer.size() == CborOutSize(&mOut) &&
683 "buffer was larger than required to hold encoded content");
684 }
685 mState = State::kEmptied;
686 return std::move(mBuffer);
687 }
688
view()689 std::span<const uint8_t> view() const {
690 assert((mState == State::kInitial || mState == State::kEncoding) &&
691 "requested view of buffer from encoder in invalid state");
692 if (mState != State::kInitial) {
693 assert((!CborOutOverflowed(&mOut) &&
694 mState != State::kOverflowed) &&
695 "buffer was too small to hold CBOR encoded content");
696 assert(mBuffer.size() == CborOutSize(&mOut) &&
697 "buffer was larger than required to hold CBOR encoded content");
698 }
699 return mBuffer;
700 }
701
size()702 size_t size() const {
703 if (mState != State::kInitial) {
704 assert((!CborOutOverflowed(&mOut) &&
705 mState != State::kOverflowed) &&
706 "requested encoding size after overflow");
707
708 return CborOutSize(&mOut);
709 } else {
710 return 0u;
711 }
712 }
713
state()714 State state() const { return mState; }
715
716 private:
717 State mState = State::kInitial;
718 /* vector or vector-like buffer */
719 V mBuffer;
720 /* cursor used for CBOR encoding which points into mBuffer */
721 struct CborOut mOut;
722 /*
723 * Used to ensure that map keys are encoded in canonical order. When the
724 * encoder is not encoding a map, the field has no value.
725 */
726 std::optional<std::span<const uint8_t>> mLastKey = std::nullopt;
727 /* determines CBOR ordering between two keys */
728 CBORCompare<uint8_t> mComparer;
729
ensureEncoding()730 void ensureEncoding() const {
731 assert(mState == State::kEncoding &&
732 "Call encodeArray, encodeTag, or encodeMap before this method");
733 }
734
ensureCapacity(const size_t capacity)735 bool ensureCapacity(const size_t capacity) {
736 if (mState == State::kInitial) {
737 mBuffer.resize(capacity);
738 CborOutInit(mBuffer.data(), mBuffer.size(), &mOut);
739 mState = mBuffer.size() == capacity ? State::kEncoding
740 : State::kInvalid;
741 }
742 return mState == State::kEncoding;
743 }
744
745 template <typename Fn>
encodeKeyCanonicalOrder(Fn fn)746 void encodeKeyCanonicalOrder(Fn fn) {
747 if (mState != State::kEncoding || !mLastKey.has_value()) {
748 mState = State::kInvalid;
749 return;
750 }
751
752 const struct CborOut preCursor = mOut;
753 fn();
754
755 const size_t newKeySz = mOut.cursor - preCursor.cursor;
756 const uint8_t* newKeyStart = preCursor.buffer + preCursor.cursor;
757 const std::span newKey(newKeyStart, newKeySz);
758
759 /*
760 * The keys in every map must be sorted lowest value to highest.
761 * Sorting is performed on the bytes of the representation of the key
762 * data items without paying attention to the 3/5 bit splitting for
763 * major types. The sorting rules are:
764 *
765 * * If two keys have different lengths, the shorter one sorts earlier;
766 *
767 * * If two keys have the same length, the one with the lower value
768 * in (byte-wise) lexical order sorts earlier.
769 */
770 if (mComparer.keyLess(mLastKey.value(), newKey)) {
771 mLastKey = newKey;
772 return;
773 }
774
775 /* CBOR encoding is not canonical */
776 mState = State::kInvalid;
777 }
778 };
779
780 using ArrayCborEncoder = CborEncoder<ArrayVector>;
781 using VectorCborEncoder = CborEncoder<std::vector<uint8_t>>;
782
783 } // namespace cbor
784