1 /*
2  * Copyright (C) 2014 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <keymaster/authorization_set.h>

#include <assert.h>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>

#include <functional>
#include <new>

#include <keymaster/android_keymaster_utils.h>
#include <keymaster/logger.h>
28 
29 namespace keymaster {
30 
is_blob_tag(keymaster_tag_t tag)31 static inline bool is_blob_tag(keymaster_tag_t tag) {
32     return (keymaster_tag_get_type(tag) == KM_BYTES || keymaster_tag_get_type(tag) == KM_BIGNUM);
33 }
34 
35 const size_t STARTING_ELEMS_CAPACITY = 8;
36 
AuthorizationSet(AuthorizationSetBuilder & builder)37 AuthorizationSet::AuthorizationSet(AuthorizationSetBuilder& builder) {
38     elems_ = builder.set.elems_;
39     builder.set.elems_ = NULL;
40 
41     elems_size_ = builder.set.elems_size_;
42     builder.set.elems_size_ = 0;
43 
44     elems_capacity_ = builder.set.elems_capacity_;
45     builder.set.elems_capacity_ = 0;
46 
47     indirect_data_ = builder.set.indirect_data_;
48     builder.set.indirect_data_ = NULL;
49 
50     indirect_data_capacity_ = builder.set.indirect_data_capacity_;
51     builder.set.indirect_data_capacity_ = 0;
52 
53     indirect_data_size_ = builder.set.indirect_data_size_;
54     builder.set.indirect_data_size_ = 0;
55 
56     error_ = builder.set.error_;
57     builder.set.error_ = OK;
58 }
59 
~AuthorizationSet()60 AuthorizationSet::~AuthorizationSet() {
61     FreeData();
62 }
63 
reserve_elems(size_t count)64 bool AuthorizationSet::reserve_elems(size_t count) {
65     if (is_valid() != OK)
66         return false;
67 
68     if (count >= elems_capacity_) {
69         keymaster_key_param_t* new_elems = new (std::nothrow) keymaster_key_param_t[count];
70         if (new_elems == NULL) {
71             set_invalid(ALLOCATION_FAILURE);
72             return false;
73         }
74         memcpy(new_elems, elems_, sizeof(*elems_) * elems_size_);
75         delete[] elems_;
76         elems_ = new_elems;
77         elems_capacity_ = count;
78     }
79     return true;
80 }
81 
reserve_indirect(size_t length)82 bool AuthorizationSet::reserve_indirect(size_t length) {
83     if (is_valid() != OK)
84         return false;
85 
86     if (length > indirect_data_capacity_) {
87         uint8_t* new_data = new (std::nothrow) uint8_t[length];
88         if (new_data == NULL) {
89             set_invalid(ALLOCATION_FAILURE);
90             return false;
91         }
92         memcpy(new_data, indirect_data_, indirect_data_size_);
93 
94         // Fix up the data pointers to point into the new region.
95         for (size_t i = 0; i < elems_size_; ++i) {
96             if (is_blob_tag(elems_[i].tag))
97                 elems_[i].blob.data = new_data + (elems_[i].blob.data - indirect_data_);
98         }
99         delete[] indirect_data_;
100         indirect_data_ = new_data;
101         indirect_data_capacity_ = length;
102     }
103     return true;
104 }
105 
Reinitialize(const keymaster_key_param_t * elems,const size_t count)106 bool AuthorizationSet::Reinitialize(const keymaster_key_param_t* elems, const size_t count) {
107     FreeData();
108 
109     if (elems == NULL || count == 0) {
110         error_ = OK;
111         return true;
112     }
113 
114     if (!reserve_elems(count))
115         return false;
116 
117     if (!reserve_indirect(ComputeIndirectDataSize(elems, count)))
118         return false;
119 
120     memcpy(elems_, elems, sizeof(keymaster_key_param_t) * count);
121     elems_size_ = count;
122     CopyIndirectData();
123     error_ = OK;
124     return true;
125 }
126 
set_invalid(Error error)127 void AuthorizationSet::set_invalid(Error error) {
128     FreeData();
129     error_ = error;
130 }
131 
Deduplicate()132 void AuthorizationSet::Deduplicate() {
133     qsort(elems_, elems_size_, sizeof(*elems_),
134           reinterpret_cast<int (*)(const void*, const void*)>(keymaster_param_compare));
135 
136     size_t invalid_count = 0;
137     for (size_t i = 1; i < size(); ++i) {
138         if (elems_[i - 1].tag == KM_TAG_INVALID)
139             ++invalid_count;
140         else if (keymaster_param_compare(elems_ + i - 1, elems_ + i) == 0) {
141             // Mark dups as invalid.  Note that this "leaks" the data referenced by KM_BYTES and
142             // KM_BIGNUM entries, but those are just pointers into indirect_data_, so it will all
143             // get cleaned up.
144             elems_[i - 1].tag = KM_TAG_INVALID;
145             ++invalid_count;
146         }
147     }
148     if (size() > 0 && elems_[size() - 1].tag == KM_TAG_INVALID)
149         ++invalid_count;
150 
151     if (invalid_count == 0)
152         return;
153 
154     // Since KM_TAG_INVALID == 0, all of the invalid entries are first.
155     elems_size_ -= invalid_count;
156     memmove(elems_, elems_ + invalid_count, size() * sizeof(*elems_));
157 }
158 
CopyToParamSet(keymaster_key_param_set_t * set) const159 void AuthorizationSet::CopyToParamSet(keymaster_key_param_set_t* set) const {
160     assert(set);
161 
162     set->length = size();
163     set->params =
164         reinterpret_cast<keymaster_key_param_t*>(malloc(sizeof(keymaster_key_param_t) * size()));
165 
166     for (size_t i = 0; i < size(); ++i) {
167         const keymaster_key_param_t src = (*this)[i];
168         keymaster_key_param_t& dst(set->params[i]);
169 
170         dst = src;
171         keymaster_tag_type_t type = keymaster_tag_get_type(src.tag);
172         if (type == KM_BIGNUM || type == KM_BYTES) {
173             void* tmp = malloc(src.blob.data_length);
174             memcpy(tmp, src.blob.data, src.blob.data_length);
175             dst.blob.data = reinterpret_cast<uint8_t*>(tmp);
176         }
177     }
178 }
179 
find(keymaster_tag_t tag,int begin) const180 int AuthorizationSet::find(keymaster_tag_t tag, int begin) const {
181     if (is_valid() != OK)
182         return -1;
183 
184     int i = ++begin;
185     while (i < (int)elems_size_ && elems_[i].tag != tag)
186         ++i;
187     if (i == (int)elems_size_)
188         return -1;
189     else
190         return i;
191 }
192 
193 keymaster_key_param_t empty;
operator [](int at) const194 keymaster_key_param_t AuthorizationSet::operator[](int at) const {
195     if (is_valid() == OK && at < (int)elems_size_) {
196         return elems_[at];
197     }
198     memset(&empty, 0, sizeof(empty));
199     return empty;
200 }
201 
push_back(const keymaster_key_param_set_t & set)202 bool AuthorizationSet::push_back(const keymaster_key_param_set_t& set) {
203     if (is_valid() != OK)
204         return false;
205 
206     if (!reserve_elems(elems_size_ + set.length))
207         return false;
208 
209     if (!reserve_indirect(indirect_data_size_ + ComputeIndirectDataSize(set.params, set.length)))
210         return false;
211 
212     for (size_t i = 0; i < set.length; ++i)
213         if (!push_back(set.params[i]))
214             return false;
215 
216     return true;
217 }
218 
push_back(keymaster_key_param_t elem)219 bool AuthorizationSet::push_back(keymaster_key_param_t elem) {
220     if (is_valid() != OK)
221         return false;
222 
223     if (elems_size_ >= elems_capacity_)
224         if (!reserve_elems(elems_capacity_ ? elems_capacity_ * 2 : STARTING_ELEMS_CAPACITY))
225             return false;
226 
227     if (is_blob_tag(elem.tag)) {
228         if (indirect_data_capacity_ - indirect_data_size_ < elem.blob.data_length)
229             if (!reserve_indirect(2 * (indirect_data_capacity_ + elem.blob.data_length)))
230                 return false;
231 
232         memcpy(indirect_data_ + indirect_data_size_, elem.blob.data, elem.blob.data_length);
233         elem.blob.data = indirect_data_ + indirect_data_size_;
234         indirect_data_size_ += elem.blob.data_length;
235     }
236 
237     elems_[elems_size_++] = elem;
238     return true;
239 }
240 
serialized_size(const keymaster_key_param_t & param)241 static size_t serialized_size(const keymaster_key_param_t& param) {
242     switch (keymaster_tag_get_type(param.tag)) {
243     case KM_INVALID:
244         return sizeof(uint32_t);
245     case KM_ENUM:
246     case KM_ENUM_REP:
247     case KM_UINT:
248     case KM_UINT_REP:
249         return sizeof(uint32_t) * 2;
250     case KM_ULONG:
251     case KM_ULONG_REP:
252     case KM_DATE:
253         return sizeof(uint32_t) + sizeof(uint64_t);
254     case KM_BOOL:
255         return sizeof(uint32_t) + 1;
256     case KM_BIGNUM:
257     case KM_BYTES:
258         return sizeof(uint32_t) * 3;
259     }
260 
261     return sizeof(uint32_t);
262 }
263 
serialize(const keymaster_key_param_t & param,uint8_t * buf,const uint8_t * end,const uint8_t * indirect_base)264 static uint8_t* serialize(const keymaster_key_param_t& param, uint8_t* buf, const uint8_t* end,
265                           const uint8_t* indirect_base) {
266     buf = append_uint32_to_buf(buf, end, param.tag);
267     switch (keymaster_tag_get_type(param.tag)) {
268     case KM_INVALID:
269         break;
270     case KM_ENUM:
271     case KM_ENUM_REP:
272         buf = append_uint32_to_buf(buf, end, param.enumerated);
273         break;
274     case KM_UINT:
275     case KM_UINT_REP:
276         buf = append_uint32_to_buf(buf, end, param.integer);
277         break;
278     case KM_ULONG:
279     case KM_ULONG_REP:
280         buf = append_uint64_to_buf(buf, end, param.long_integer);
281         break;
282     case KM_DATE:
283         buf = append_uint64_to_buf(buf, end, param.date_time);
284         break;
285     case KM_BOOL:
286         if (buf < end)
287             *buf = static_cast<uint8_t>(param.boolean);
288         buf++;
289         break;
290     case KM_BIGNUM:
291     case KM_BYTES:
292         buf = append_uint32_to_buf(buf, end, param.blob.data_length);
293         buf = append_uint32_to_buf(buf, end, param.blob.data - indirect_base);
294         break;
295     }
296     return buf;
297 }
298 
deserialize(keymaster_key_param_t * param,const uint8_t ** buf_ptr,const uint8_t * end,const uint8_t * indirect_base,const uint8_t * indirect_end)299 static bool deserialize(keymaster_key_param_t* param, const uint8_t** buf_ptr, const uint8_t* end,
300                         const uint8_t* indirect_base, const uint8_t* indirect_end) {
301     if (!copy_uint32_from_buf(buf_ptr, end, &param->tag))
302         return false;
303 
304     switch (keymaster_tag_get_type(param->tag)) {
305     case KM_INVALID:
306         return false;
307     case KM_ENUM:
308     case KM_ENUM_REP:
309         return copy_uint32_from_buf(buf_ptr, end, &param->enumerated);
310     case KM_UINT:
311     case KM_UINT_REP:
312         return copy_uint32_from_buf(buf_ptr, end, &param->integer);
313     case KM_ULONG:
314     case KM_ULONG_REP:
315         return copy_uint64_from_buf(buf_ptr, end, &param->long_integer);
316     case KM_DATE:
317         return copy_uint64_from_buf(buf_ptr, end, &param->date_time);
318         break;
319     case KM_BOOL:
320         if (*buf_ptr < end) {
321             param->boolean = static_cast<bool>(**buf_ptr);
322             (*buf_ptr)++;
323             return true;
324         }
325         return false;
326 
327     case KM_BIGNUM:
328     case KM_BYTES: {
329         uint32_t offset;
330         if (!copy_uint32_from_buf(buf_ptr, end, &param->blob.data_length) ||
331             !copy_uint32_from_buf(buf_ptr, end, &offset))
332             return false;
333         if (param->blob.data_length + offset < param->blob.data_length ||  // Overflow check
334             static_cast<ptrdiff_t>(offset) > indirect_end - indirect_base ||
335             static_cast<ptrdiff_t>(offset + param->blob.data_length) > indirect_end - indirect_base)
336             return false;
337         param->blob.data = indirect_base + offset;
338         return true;
339     }
340     }
341 
342     return false;
343 }
344 
SerializedSizeOfElements() const345 size_t AuthorizationSet::SerializedSizeOfElements() const {
346     size_t size = 0;
347     for (size_t i = 0; i < elems_size_; ++i) {
348         size += serialized_size(elems_[i]);
349     }
350     return size;
351 }
352 
SerializedSize() const353 size_t AuthorizationSet::SerializedSize() const {
354     return sizeof(uint32_t) +           // Size of indirect_data_
355            indirect_data_size_ +        // indirect_data_
356            sizeof(uint32_t) +           // Number of elems_
357            sizeof(uint32_t) +           // Size of elems_
358            SerializedSizeOfElements();  // elems_
359 }
360 
Serialize(uint8_t * buf,const uint8_t * end) const361 uint8_t* AuthorizationSet::Serialize(uint8_t* buf, const uint8_t* end) const {
362     buf = append_size_and_data_to_buf(buf, end, indirect_data_, indirect_data_size_);
363     buf = append_uint32_to_buf(buf, end, elems_size_);
364     buf = append_uint32_to_buf(buf, end, SerializedSizeOfElements());
365     for (size_t i = 0; i < elems_size_; ++i) {
366         buf = serialize(elems_[i], buf, end, indirect_data_);
367     }
368     return buf;
369 }
370 
DeserializeIndirectData(const uint8_t ** buf_ptr,const uint8_t * end)371 bool AuthorizationSet::DeserializeIndirectData(const uint8_t** buf_ptr, const uint8_t* end) {
372     UniquePtr<uint8_t[]> indirect_buf;
373     if (!copy_size_and_data_from_buf(buf_ptr, end, &indirect_data_size_, &indirect_buf)) {
374         LOG_E("Malformed data found in AuthorizationSet deserialization", 0);
375         set_invalid(MALFORMED_DATA);
376         return false;
377     }
378     indirect_data_ = indirect_buf.release();
379     return true;
380 }
381 
DeserializeElementsData(const uint8_t ** buf_ptr,const uint8_t * end)382 bool AuthorizationSet::DeserializeElementsData(const uint8_t** buf_ptr, const uint8_t* end) {
383     uint32_t elements_count;
384     uint32_t elements_size;
385     if (!copy_uint32_from_buf(buf_ptr, end, &elements_count) ||
386         !copy_uint32_from_buf(buf_ptr, end, &elements_size)) {
387         LOG_E("Malformed data found in AuthorizationSet deserialization", 0);
388         set_invalid(MALFORMED_DATA);
389         return false;
390     }
391 
392     // Note that the following validation of elements_count is weak, but it prevents allocation of
393     // elems_ arrays which are clearly too large to be reasonable.
394     if (static_cast<ptrdiff_t>(elements_size) > end - *buf_ptr ||
395         elements_count * sizeof(uint32_t) > elements_size ||
396         *buf_ptr + (elements_count * sizeof(*elems_)) < *buf_ptr) {
397         LOG_E("Malformed data found in AuthorizationSet deserialization", 0);
398         set_invalid(MALFORMED_DATA);
399         return false;
400     }
401 
402     if (!reserve_elems(elements_count))
403         return false;
404 
405     uint8_t* indirect_end = indirect_data_ + indirect_data_size_;
406     const uint8_t* elements_end = *buf_ptr + elements_size;
407     for (size_t i = 0; i < elements_count; ++i) {
408         if (!deserialize(elems_ + i, buf_ptr, elements_end, indirect_data_, indirect_end)) {
409             LOG_E("Malformed data found in AuthorizationSet deserialization", 0);
410             set_invalid(MALFORMED_DATA);
411             return false;
412         }
413     }
414     elems_size_ = elements_count;
415     return true;
416 }
417 
Deserialize(const uint8_t ** buf_ptr,const uint8_t * end)418 bool AuthorizationSet::Deserialize(const uint8_t** buf_ptr, const uint8_t* end) {
419     FreeData();
420 
421     if (!DeserializeIndirectData(buf_ptr, end) || !DeserializeElementsData(buf_ptr, end))
422         return false;
423 
424     if (indirect_data_size_ != ComputeIndirectDataSize(elems_, elems_size_)) {
425         LOG_E("Malformed data found in AuthorizationSet deserialization", 0);
426         set_invalid(MALFORMED_DATA);
427         return false;
428     }
429     return true;
430 }
431 
Clear()432 void AuthorizationSet::Clear() {
433     memset_s(elems_, 0, elems_size_ * sizeof(keymaster_key_param_t));
434     memset_s(indirect_data_, 0, indirect_data_size_);
435     elems_size_ = 0;
436     indirect_data_size_ = 0;
437 }
438 
FreeData()439 void AuthorizationSet::FreeData() {
440     Clear();
441 
442     delete[] elems_;
443     delete[] indirect_data_;
444 
445     elems_ = NULL;
446     indirect_data_ = NULL;
447     elems_capacity_ = 0;
448     indirect_data_capacity_ = 0;
449     error_ = OK;
450 }
451 
452 /* static */
ComputeIndirectDataSize(const keymaster_key_param_t * elems,size_t count)453 size_t AuthorizationSet::ComputeIndirectDataSize(const keymaster_key_param_t* elems, size_t count) {
454     size_t size = 0;
455     for (size_t i = 0; i < count; ++i) {
456         if (is_blob_tag(elems[i].tag)) {
457             size += elems[i].blob.data_length;
458         }
459     }
460     return size;
461 }
462 
CopyIndirectData()463 void AuthorizationSet::CopyIndirectData() {
464     memset_s(indirect_data_, 0, indirect_data_capacity_);
465 
466     uint8_t* indirect_data_pos = indirect_data_;
467     for (size_t i = 0; i < elems_size_; ++i) {
468         assert(indirect_data_pos <= indirect_data_ + indirect_data_capacity_);
469         if (is_blob_tag(elems_[i].tag)) {
470             memcpy(indirect_data_pos, elems_[i].blob.data, elems_[i].blob.data_length);
471             elems_[i].blob.data = indirect_data_pos;
472             indirect_data_pos += elems_[i].blob.data_length;
473         }
474     }
475     assert(indirect_data_pos == indirect_data_ + indirect_data_capacity_);
476     indirect_data_size_ = indirect_data_pos - indirect_data_;
477 }
478 
GetTagCount(keymaster_tag_t tag) const479 size_t AuthorizationSet::GetTagCount(keymaster_tag_t tag) const {
480     size_t count = 0;
481     for (int pos = -1; (pos = find(tag, pos)) != -1;)
482         ++count;
483     return count;
484 }
485 
GetTagValueEnum(keymaster_tag_t tag,uint32_t * val) const486 bool AuthorizationSet::GetTagValueEnum(keymaster_tag_t tag, uint32_t* val) const {
487     int pos = find(tag);
488     if (pos == -1) {
489         return false;
490     }
491     *val = elems_[pos].enumerated;
492     return true;
493 }
494 
GetTagValueEnumRep(keymaster_tag_t tag,size_t instance,uint32_t * val) const495 bool AuthorizationSet::GetTagValueEnumRep(keymaster_tag_t tag, size_t instance,
496                                           uint32_t* val) const {
497     size_t count = 0;
498     int pos = -1;
499     while (count <= instance) {
500         pos = find(tag, pos);
501         if (pos == -1) {
502             return false;
503         }
504         ++count;
505     }
506     *val = elems_[pos].enumerated;
507     return true;
508 }
509 
GetTagValueInt(keymaster_tag_t tag,uint32_t * val) const510 bool AuthorizationSet::GetTagValueInt(keymaster_tag_t tag, uint32_t* val) const {
511     int pos = find(tag);
512     if (pos == -1) {
513         return false;
514     }
515     *val = elems_[pos].integer;
516     return true;
517 }
518 
GetTagValueIntRep(keymaster_tag_t tag,size_t instance,uint32_t * val) const519 bool AuthorizationSet::GetTagValueIntRep(keymaster_tag_t tag, size_t instance,
520                                          uint32_t* val) const {
521     size_t count = 0;
522     int pos = -1;
523     while (count <= instance) {
524         pos = find(tag, pos);
525         if (pos == -1) {
526             return false;
527         }
528         ++count;
529     }
530     *val = elems_[pos].integer;
531     return true;
532 }
533 
GetTagValueLong(keymaster_tag_t tag,uint64_t * val) const534 bool AuthorizationSet::GetTagValueLong(keymaster_tag_t tag, uint64_t* val) const {
535     int pos = find(tag);
536     if (pos == -1) {
537         return false;
538     }
539     *val = elems_[pos].long_integer;
540     return true;
541 }
542 
GetTagValueLongRep(keymaster_tag_t tag,size_t instance,uint64_t * val) const543 bool AuthorizationSet::GetTagValueLongRep(keymaster_tag_t tag, size_t instance,
544                                           uint64_t* val) const {
545     size_t count = 0;
546     int pos = -1;
547     while (count <= instance) {
548         pos = find(tag, pos);
549         if (pos == -1) {
550             return false;
551         }
552         ++count;
553     }
554     *val = elems_[pos].long_integer;
555     return true;
556 }
557 
GetTagValueDate(keymaster_tag_t tag,uint64_t * val) const558 bool AuthorizationSet::GetTagValueDate(keymaster_tag_t tag, uint64_t* val) const {
559     int pos = find(tag);
560     if (pos == -1) {
561         return false;
562     }
563     *val = elems_[pos].date_time;
564     return true;
565 }
566 
GetTagValueBlob(keymaster_tag_t tag,keymaster_blob_t * val) const567 bool AuthorizationSet::GetTagValueBlob(keymaster_tag_t tag, keymaster_blob_t* val) const {
568     int pos = find(tag);
569     if (pos == -1) {
570         return false;
571     }
572     *val = elems_[pos].blob;
573     return true;
574 }
575 
GetTagValueBool(keymaster_tag_t tag) const576 bool AuthorizationSet::GetTagValueBool(keymaster_tag_t tag) const {
577     int pos = find(tag);
578     if (pos == -1) {
579         return false;
580     }
581     assert(elems_[pos].boolean);
582     return elems_[pos].boolean;
583 }
584 
ContainsEnumValue(keymaster_tag_t tag,uint32_t value) const585 bool AuthorizationSet::ContainsEnumValue(keymaster_tag_t tag, uint32_t value) const {
586     for (auto& entry : *this)
587         if (entry.tag == tag && entry.enumerated == value)
588             return true;
589     return false;
590 }
591 
592 }  // namespace keymaster
593