1 /*
2  * Copyright (C) 2014 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <assert.h>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>

#include <functional>
#include <new>

#include <keymaster/authorization_set.h>
#include <keymaster/google_keymaster_utils.h>
25 
26 namespace keymaster {
27 
is_blob_tag(keymaster_tag_t tag)28 static inline bool is_blob_tag(keymaster_tag_t tag) {
29     return (keymaster_tag_get_type(tag) == KM_BYTES || keymaster_tag_get_type(tag) == KM_BIGNUM);
30 }
31 
// Element capacity allocated by push_back() when the set has no element storage yet; after
// that, push_back() doubles the capacity each time it runs out.
static const size_t STARTING_ELEMS_CAPACITY = 8;
33 
/**
 * Copy constructor: builds an independent deep copy of |set|, with its own element array and its
 * own copy of every blob's bytes.
 */
AuthorizationSet::AuthorizationSet(const AuthorizationSet& set)
    : Serializable(), elems_(NULL), indirect_data_(NULL) {
    // Both owned pointers must start out NULL: Reinitialize() begins with FreeData(), which
    // delete[]s them.
    const keymaster_key_param_t* source_elems = set.elems_;
    const size_t source_count = set.elems_size_;
    Reinitialize(source_elems, source_count);
}
38 
~AuthorizationSet()39 AuthorizationSet::~AuthorizationSet() {
40     FreeData();
41 }
42 
reserve_elems(size_t count)43 bool AuthorizationSet::reserve_elems(size_t count) {
44     if (is_valid() != OK)
45         return false;
46 
47     if (count >= elems_capacity_) {
48         keymaster_key_param_t* new_elems = new keymaster_key_param_t[count];
49         if (new_elems == NULL) {
50             set_invalid(ALLOCATION_FAILURE);
51             return false;
52         }
53         memcpy(new_elems, elems_, sizeof(*elems_) * elems_size_);
54         delete[] elems_;
55         elems_ = new_elems;
56         elems_capacity_ = count;
57     }
58     return true;
59 }
60 
reserve_indirect(size_t length)61 bool AuthorizationSet::reserve_indirect(size_t length) {
62     if (is_valid() != OK)
63         return false;
64 
65     if (length > indirect_data_capacity_) {
66         uint8_t* new_data = new uint8_t[length];
67         if (new_data == NULL) {
68             set_invalid(ALLOCATION_FAILURE);
69             return false;
70         }
71         memcpy(new_data, indirect_data_, indirect_data_size_);
72 
73         // Fix up the data pointers to point into the new region.
74         for (size_t i = 0; i < elems_size_; ++i) {
75             if (is_blob_tag(elems_[i].tag))
76                 elems_[i].blob.data = new_data + (elems_[i].blob.data - indirect_data_);
77         }
78         delete[] indirect_data_;
79         indirect_data_ = new_data;
80         indirect_data_capacity_ = length;
81     }
82     return true;
83 }
84 
Reinitialize(const keymaster_key_param_t * elems,const size_t count)85 bool AuthorizationSet::Reinitialize(const keymaster_key_param_t* elems, const size_t count) {
86     FreeData();
87 
88     if (!reserve_elems(count))
89         return false;
90 
91     if (!reserve_indirect(ComputeIndirectDataSize(elems, count)))
92         return false;
93 
94     memcpy(elems_, elems, sizeof(keymaster_key_param_t) * count);
95     elems_size_ = count;
96     CopyIndirectData();
97     error_ = OK;
98     return true;
99 }
100 
set_invalid(Error error)101 void AuthorizationSet::set_invalid(Error error) {
102     FreeData();
103     error_ = error;
104 }
105 
find(keymaster_tag_t tag,int begin) const106 int AuthorizationSet::find(keymaster_tag_t tag, int begin) const {
107     if (is_valid() != OK)
108         return -1;
109 
110     int i = ++begin;
111     while (i < (int)elems_size_ && elems_[i].tag != tag)
112         ++i;
113     if (i == (int)elems_size_)
114         return -1;
115     else
116         return i;
117 }
118 
119 keymaster_key_param_t empty;
operator [](int at) const120 keymaster_key_param_t AuthorizationSet::operator[](int at) const {
121     if (is_valid() == OK && at < (int)elems_size_) {
122         return elems_[at];
123     }
124     memset(&empty, 0, sizeof(empty));
125     return empty;
126 }
127 
/**
 * Three-way comparison built on operator< and operator>. Returns -1 if a < b, 1 if a > b, and 0
 * otherwise (including when the values are unordered).
 */
template <typename T> int comparator(const T& a, const T& b) {
    if (a < b)
        return -1;
    return (a > b) ? 1 : 0;
}
136 
param_comparator(const void * a,const void * b)137 static int param_comparator(const void* a, const void* b) {
138     const keymaster_key_param_t* lhs = static_cast<const keymaster_key_param_t*>(a);
139     const keymaster_key_param_t* rhs = static_cast<const keymaster_key_param_t*>(b);
140 
141     if (lhs->tag < rhs->tag)
142         return -1;
143     else if (lhs->tag > rhs->tag)
144         return 1;
145     else
146         switch (keymaster_tag_get_type(lhs->tag)) {
147         default:
148         case KM_INVALID:
149             return 0;
150         case KM_ENUM:
151         case KM_ENUM_REP:
152             return comparator(lhs->enumerated, rhs->enumerated);
153         case KM_INT:
154         case KM_INT_REP:
155             return comparator(lhs->integer, rhs->integer);
156         case KM_LONG:
157             return comparator(lhs->long_integer, rhs->long_integer);
158         case KM_DATE:
159             return comparator(lhs->date_time, rhs->date_time);
160         case KM_BOOL:
161             return comparator(lhs->boolean, rhs->boolean);
162         case KM_BIGNUM:
163         case KM_BYTES: {
164             size_t min_len = lhs->blob.data_length;
165             if (rhs->blob.data_length < min_len)
166                 min_len = rhs->blob.data_length;
167 
168             if (lhs->blob.data_length == rhs->blob.data_length && min_len > 0)
169                 return memcmp(lhs->blob.data, rhs->blob.data, min_len);
170             int cmp_result = memcmp(lhs->blob.data, rhs->blob.data, min_len);
171             if (cmp_result == 0) {
172                 // The blobs are equal up to the length of the shortest (which may have length 0),
173                 // so the shorter is less, the longer is greater and if they have the same length
174                 // they're identical.
175                 return comparator(lhs->blob.data_length, rhs->blob.data_length);
176             }
177             return cmp_result;
178         } break;
179         }
180 }
181 
push_back(const AuthorizationSet & set)182 bool AuthorizationSet::push_back(const AuthorizationSet& set) {
183     if (is_valid() != OK)
184         return false;
185 
186     if (!reserve_elems(elems_size_ + set.elems_size_))
187         return false;
188 
189     if (!reserve_indirect(indirect_data_size_ + set.indirect_data_size_))
190         return false;
191 
192     for (size_t i = 0; i < set.size(); ++i)
193         if (!push_back(set[i]))
194             return false;
195 
196     return true;
197 }
198 
push_back(keymaster_key_param_t elem)199 bool AuthorizationSet::push_back(keymaster_key_param_t elem) {
200     if (is_valid() != OK)
201         return false;
202 
203     if (elems_size_ >= elems_capacity_)
204         if (!reserve_elems(elems_capacity_ ? elems_capacity_ * 2 : STARTING_ELEMS_CAPACITY))
205             return false;
206 
207     if (is_blob_tag(elem.tag)) {
208         if (indirect_data_capacity_ - indirect_data_size_ < elem.blob.data_length)
209             if (!reserve_indirect(2 * (indirect_data_capacity_ + elem.blob.data_length)))
210                 return false;
211 
212         memcpy(indirect_data_ + indirect_data_size_, elem.blob.data, elem.blob.data_length);
213         elem.blob.data = indirect_data_ + indirect_data_size_;
214         indirect_data_size_ += elem.blob.data_length;
215     }
216 
217     elems_[elems_size_++] = elem;
218     return true;
219 }
220 
serialized_size(const keymaster_key_param_t & param)221 static size_t serialized_size(const keymaster_key_param_t& param) {
222     switch (keymaster_tag_get_type(param.tag)) {
223     case KM_INVALID:
224     default:
225         return sizeof(uint32_t);
226     case KM_ENUM:
227     case KM_ENUM_REP:
228     case KM_INT:
229     case KM_INT_REP:
230         return sizeof(uint32_t) * 2;
231     case KM_LONG:
232     case KM_DATE:
233         return sizeof(uint32_t) + sizeof(uint64_t);
234     case KM_BOOL:
235         return sizeof(uint32_t) + 1;
236         break;
237     case KM_BIGNUM:
238     case KM_BYTES:
239         return sizeof(uint32_t) * 3;
240     }
241 }
242 
serialize(const keymaster_key_param_t & param,uint8_t * buf,const uint8_t * end,const uint8_t * indirect_base)243 static uint8_t* serialize(const keymaster_key_param_t& param, uint8_t* buf, const uint8_t* end,
244                           const uint8_t* indirect_base) {
245     buf = append_uint32_to_buf(buf, end, param.tag);
246     switch (keymaster_tag_get_type(param.tag)) {
247     case KM_INVALID:
248         break;
249     case KM_ENUM:
250     case KM_ENUM_REP:
251         buf = append_uint32_to_buf(buf, end, param.enumerated);
252         break;
253     case KM_INT:
254     case KM_INT_REP:
255         buf = append_uint32_to_buf(buf, end, param.integer);
256         break;
257     case KM_LONG:
258         buf = append_uint64_to_buf(buf, end, param.long_integer);
259         break;
260     case KM_DATE:
261         buf = append_uint64_to_buf(buf, end, param.date_time);
262         break;
263     case KM_BOOL:
264         if (buf < end)
265             *buf = static_cast<uint8_t>(param.boolean);
266         buf++;
267         break;
268     case KM_BIGNUM:
269     case KM_BYTES:
270         buf = append_uint32_to_buf(buf, end, param.blob.data_length);
271         buf = append_uint32_to_buf(buf, end, param.blob.data - indirect_base);
272         break;
273     }
274     return buf;
275 }
276 
deserialize(keymaster_key_param_t * param,const uint8_t ** buf_ptr,const uint8_t * end,const uint8_t * indirect_base,const uint8_t * indirect_end)277 static bool deserialize(keymaster_key_param_t* param, const uint8_t** buf_ptr, const uint8_t* end,
278                         const uint8_t* indirect_base, const uint8_t* indirect_end) {
279     if (!copy_uint32_from_buf(buf_ptr, end, &param->tag))
280         return false;
281 
282     switch (keymaster_tag_get_type(param->tag)) {
283     default:
284     case KM_INVALID:
285         return false;
286     case KM_ENUM:
287     case KM_ENUM_REP:
288         return copy_uint32_from_buf(buf_ptr, end, &param->enumerated);
289     case KM_INT:
290     case KM_INT_REP:
291         return copy_uint32_from_buf(buf_ptr, end, &param->integer);
292     case KM_LONG:
293         return copy_uint64_from_buf(buf_ptr, end, &param->long_integer);
294     case KM_DATE:
295         return copy_uint64_from_buf(buf_ptr, end, &param->date_time);
296         break;
297     case KM_BOOL:
298         if (*buf_ptr < end) {
299             param->boolean = static_cast<bool>(**buf_ptr);
300             (*buf_ptr)++;
301             return true;
302         }
303         return false;
304 
305     case KM_BIGNUM:
306     case KM_BYTES: {
307         uint32_t offset;
308         if (!copy_uint32_from_buf(buf_ptr, end, &param->blob.data_length) ||
309             !copy_uint32_from_buf(buf_ptr, end, &offset))
310             return false;
311         if (static_cast<ptrdiff_t>(offset) > indirect_end - indirect_base ||
312             static_cast<ptrdiff_t>(offset + param->blob.data_length) > indirect_end - indirect_base)
313             return false;
314         param->blob.data = indirect_base + offset;
315         return true;
316     }
317     }
318 }
319 
SerializedSizeOfElements() const320 size_t AuthorizationSet::SerializedSizeOfElements() const {
321     size_t size = 0;
322     for (size_t i = 0; i < elems_size_; ++i) {
323         size += serialized_size(elems_[i]);
324     }
325     return size;
326 }
327 
SerializedSize() const328 size_t AuthorizationSet::SerializedSize() const {
329     return sizeof(uint32_t) +           // Size of indirect_data_
330            indirect_data_size_ +        // indirect_data_
331            sizeof(uint32_t) +           // Number of elems_
332            sizeof(uint32_t) +           // Size of elems_
333            SerializedSizeOfElements();  // elems_
334 }
335 
Serialize(uint8_t * buf,const uint8_t * end) const336 uint8_t* AuthorizationSet::Serialize(uint8_t* buf, const uint8_t* end) const {
337     buf = append_size_and_data_to_buf(buf, end, indirect_data_, indirect_data_size_);
338     buf = append_uint32_to_buf(buf, end, elems_size_);
339     buf = append_uint32_to_buf(buf, end, SerializedSizeOfElements());
340     for (size_t i = 0; i < elems_size_; ++i) {
341         buf = serialize(elems_[i], buf, end, indirect_data_);
342     }
343     return buf;
344 }
345 
DeserializeIndirectData(const uint8_t ** buf_ptr,const uint8_t * end)346 bool AuthorizationSet::DeserializeIndirectData(const uint8_t** buf_ptr, const uint8_t* end) {
347     UniquePtr<uint8_t[]> indirect_buf;
348     if (!copy_size_and_data_from_buf(buf_ptr, end, &indirect_data_size_, &indirect_buf)) {
349         set_invalid(MALFORMED_DATA);
350         return false;
351     }
352     indirect_data_ = indirect_buf.release();
353     return true;
354 }
355 
DeserializeElementsData(const uint8_t ** buf_ptr,const uint8_t * end)356 bool AuthorizationSet::DeserializeElementsData(const uint8_t** buf_ptr, const uint8_t* end) {
357     uint32_t elements_count;
358     uint32_t elements_size;
359     if (!copy_uint32_from_buf(buf_ptr, end, &elements_count) ||
360         !copy_uint32_from_buf(buf_ptr, end, &elements_size)) {
361         set_invalid(MALFORMED_DATA);
362         return false;
363     }
364 
365     // Note that the following validation of elements_count is weak, but it prevents allocation of
366     // elems_ arrays which are clearly too large to be reasonable.
367     if (static_cast<ptrdiff_t>(elements_size) > end - *buf_ptr ||
368         elements_count * sizeof(uint32_t) > elements_size) {
369         set_invalid(MALFORMED_DATA);
370         return false;
371     }
372 
373     if (!reserve_elems(elements_count))
374         return false;
375 
376     uint8_t* indirect_end = indirect_data_ + indirect_data_size_;
377     const uint8_t* elements_end = *buf_ptr + elements_size;
378     for (size_t i = 0; i < elements_count; ++i) {
379         if (!deserialize(elems_ + i, buf_ptr, elements_end, indirect_data_, indirect_end)) {
380             set_invalid(MALFORMED_DATA);
381             return false;
382         }
383     }
384     elems_size_ = elements_count;
385     return true;
386 }
387 
Deserialize(const uint8_t ** buf_ptr,const uint8_t * end)388 bool AuthorizationSet::Deserialize(const uint8_t** buf_ptr, const uint8_t* end) {
389     FreeData();
390 
391     if (!DeserializeIndirectData(buf_ptr, end) || !DeserializeElementsData(buf_ptr, end))
392         return false;
393 
394     if (indirect_data_size_ != ComputeIndirectDataSize(elems_, elems_size_)) {
395         set_invalid(MALFORMED_DATA);
396         return false;
397     }
398     return true;
399 }
400 
FreeData()401 void AuthorizationSet::FreeData() {
402     if (elems_ != NULL)
403         memset_s(elems_, 0, elems_size_ * sizeof(keymaster_key_param_t));
404     if (indirect_data_ != NULL)
405         memset_s(indirect_data_, 0, indirect_data_size_);
406 
407     delete[] elems_;
408     delete[] indirect_data_;
409 
410     elems_ = NULL;
411     indirect_data_ = NULL;
412     elems_size_ = 0;
413     elems_capacity_ = 0;
414     indirect_data_size_ = 0;
415     indirect_data_capacity_ = 0;
416     error_ = OK;
417 }
418 
419 /* static */
ComputeIndirectDataSize(const keymaster_key_param_t * elems,size_t count)420 size_t AuthorizationSet::ComputeIndirectDataSize(const keymaster_key_param_t* elems, size_t count) {
421     size_t size = 0;
422     for (size_t i = 0; i < count; ++i) {
423         if (is_blob_tag(elems[i].tag)) {
424             size += elems[i].blob.data_length;
425         }
426     }
427     return size;
428 }
429 
CopyIndirectData()430 void AuthorizationSet::CopyIndirectData() {
431     memset_s(indirect_data_, 0, indirect_data_capacity_);
432 
433     uint8_t* indirect_data_pos = indirect_data_;
434     for (size_t i = 0; i < elems_size_; ++i) {
435         assert(indirect_data_pos <= indirect_data_ + indirect_data_capacity_);
436         if (is_blob_tag(elems_[i].tag)) {
437             memcpy(indirect_data_pos, elems_[i].blob.data, elems_[i].blob.data_length);
438             elems_[i].blob.data = indirect_data_pos;
439             indirect_data_pos += elems_[i].blob.data_length;
440         }
441     }
442     assert(indirect_data_pos == indirect_data_ + indirect_data_capacity_);
443     indirect_data_size_ = indirect_data_pos - indirect_data_;
444 }
445 
GetTagValueEnum(keymaster_tag_t tag,uint32_t * val) const446 bool AuthorizationSet::GetTagValueEnum(keymaster_tag_t tag, uint32_t* val) const {
447     int pos = find(tag);
448     if (pos == -1) {
449         return false;
450     }
451     *val = elems_[pos].enumerated;
452     return true;
453 }
454 
GetTagValueEnumRep(keymaster_tag_t tag,size_t instance,uint32_t * val) const455 bool AuthorizationSet::GetTagValueEnumRep(keymaster_tag_t tag, size_t instance,
456                                           uint32_t* val) const {
457     size_t count = 0;
458     int pos = -1;
459     while (count <= instance) {
460         pos = find(tag, pos);
461         if (pos == -1) {
462             return false;
463         }
464         ++count;
465     }
466     *val = elems_[pos].enumerated;
467     return true;
468 }
469 
GetTagValueInt(keymaster_tag_t tag,uint32_t * val) const470 bool AuthorizationSet::GetTagValueInt(keymaster_tag_t tag, uint32_t* val) const {
471     int pos = find(tag);
472     if (pos == -1) {
473         return false;
474     }
475     *val = elems_[pos].integer;
476     return true;
477 }
478 
GetTagValueIntRep(keymaster_tag_t tag,size_t instance,uint32_t * val) const479 bool AuthorizationSet::GetTagValueIntRep(keymaster_tag_t tag, size_t instance,
480                                          uint32_t* val) const {
481     size_t count = 0;
482     int pos = -1;
483     while (count <= instance) {
484         pos = find(tag, pos);
485         if (pos == -1) {
486             return false;
487         }
488         ++count;
489     }
490     *val = elems_[pos].integer;
491     return true;
492 }
493 
GetTagValueLong(keymaster_tag_t tag,uint64_t * val) const494 bool AuthorizationSet::GetTagValueLong(keymaster_tag_t tag, uint64_t* val) const {
495     int pos = find(tag);
496     if (pos == -1) {
497         return false;
498     }
499     *val = elems_[pos].long_integer;
500     return true;
501 }
502 
GetTagValueDate(keymaster_tag_t tag,uint64_t * val) const503 bool AuthorizationSet::GetTagValueDate(keymaster_tag_t tag, uint64_t* val) const {
504     int pos = find(tag);
505     if (pos == -1) {
506         return false;
507     }
508     *val = elems_[pos].date_time;
509     return true;
510 }
511 
GetTagValueBlob(keymaster_tag_t tag,keymaster_blob_t * val) const512 bool AuthorizationSet::GetTagValueBlob(keymaster_tag_t tag, keymaster_blob_t* val) const {
513     int pos = find(tag);
514     if (pos == -1) {
515         return false;
516     }
517     *val = elems_[pos].blob;
518     return true;
519 }
520 
521 }  // namespace keymaster
522