1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #define LOG_TAG "hidl_JsonWebKey"
17 
18 #include <utils/Log.h>
19 
20 #include "JsonWebKey.h"
21 
22 #include "Base64.h"
23 
namespace {
// File-local string constants used while parsing a JSON Web Key Set.

// Padding character of the base64 alphabet.
const std::string kBase64Padding{"="};
// Tag that must appear in a JSON Web Key Set.
const std::string kKeysTag{"keys"};
// Tag whose value is the key type.
const std::string kKeyTypeTag{"kty"};
// Tag whose value is the base64url encoded key.
const std::string kKeyTag{"k"};
// Tag whose value is the base64url encoded key id.
const std::string kKeyIdTag{"kid"};
// Session type tag and its two values.
const std::string kMediaSessionType{"type"};
const std::string kPersistentLicenseSession{"persistent-license"};
const std::string kTemporaryLicenseSession{"temporary"};
// The only accepted "kty" value (symmetric key).
const std::string kSymmetricKeyValue{"oct"};
}  // namespace
35 
36 namespace android {
37 namespace hardware {
38 namespace drm {
39 namespace V1_2 {
40 namespace clearkey {
41 
JsonWebKey()42 JsonWebKey::JsonWebKey() {
43 }
44 
~JsonWebKey()45 JsonWebKey::~JsonWebKey() {
46 }
47 
48 /*
49  * Parses a JSON Web Key Set string, initializes a KeyMap with key id:key
50  * pairs from the JSON Web Key Set. Both key ids and keys are base64url
51  * encoded. The KeyMap contains base64url decoded key id:key pairs.
52  *
53  * @return Returns false for errors, true for success.
54  */
extractKeysFromJsonWebKeySet(const std::string & jsonWebKeySet,KeyMap * keys)55 bool JsonWebKey::extractKeysFromJsonWebKeySet(const std::string& jsonWebKeySet,
56         KeyMap* keys) {
57 
58     keys->clear();
59 
60     if (!parseJsonWebKeySet(jsonWebKeySet, &mJsonObjects)) {
61         return false;
62     }
63 
64     // mJsonObjects[0] contains the entire JSON Web Key Set, including
65     // all the base64 encoded keys. Each key is also stored separately as
66     // a JSON object in mJsonObjects[1..n] where n is the total
67     // number of keys in the set.
68     if (!isJsonWebKeySet(mJsonObjects[0])) {
69         return false;
70     }
71 
72     std::string encodedKey, encodedKeyId;
73     std::vector<uint8_t> decodedKey, decodedKeyId;
74 
75     // mJsonObjects[1] contains the first JSON Web Key in the set
76     for (size_t i = 1; i < mJsonObjects.size(); ++i) {
77         encodedKeyId.clear();
78         encodedKey.clear();
79 
80         if (!parseJsonObject(mJsonObjects[i], &mTokens))
81             return false;
82 
83         if (findKey(mJsonObjects[i], &encodedKeyId, &encodedKey)) {
84             if (encodedKeyId.empty() || encodedKey.empty()) {
85                 ALOGE("Must have both key id and key in the JsonWebKey set.");
86                 continue;
87             }
88 
89             if (!decodeBase64String(encodedKeyId, &decodedKeyId)) {
90                 ALOGE("Failed to decode key id(%s)", encodedKeyId.c_str());
91                 continue;
92             }
93 
94             if (!decodeBase64String(encodedKey, &decodedKey)) {
95                 ALOGE("Failed to decode key(%s)", encodedKey.c_str());
96                 continue;
97             }
98 
99             keys->insert(std::pair<std::vector<uint8_t>,
100                     std::vector<uint8_t> >(decodedKeyId, decodedKey));
101         }
102     }
103     return true;
104 }
105 
decodeBase64String(const std::string & encodedText,std::vector<uint8_t> * decodedText)106 bool JsonWebKey::decodeBase64String(const std::string& encodedText,
107         std::vector<uint8_t>* decodedText) {
108 
109     decodedText->clear();
110 
111     // encodedText should not contain padding characters as per EME spec.
112     if (encodedText.find(kBase64Padding) != std::string::npos) {
113         return false;
114     }
115 
116     // Since decodeBase64() requires padding characters,
117     // add them so length of encodedText is exactly a multiple of 4.
118     int remainder = encodedText.length() % 4;
119     std::string paddedText(encodedText);
120     if (remainder > 0) {
121         for (int i = 0; i < 4 - remainder; ++i) {
122             paddedText.append(kBase64Padding);
123         }
124     }
125 
126     sp<Buffer> buffer = decodeBase64(paddedText);
127     if (buffer == nullptr) {
128         ALOGE("Malformed base64 encoded content found.");
129         return false;
130     }
131 
132     decodedText->insert(decodedText->end(), buffer->base(), buffer->base() + buffer->size());
133     return true;
134 }
135 
findKey(const std::string & jsonObject,std::string * keyId,std::string * encodedKey)136 bool JsonWebKey::findKey(const std::string& jsonObject, std::string* keyId,
137         std::string* encodedKey) {
138 
139     std::string key, value;
140 
141     // Only allow symmetric key, i.e. "kty":"oct" pair.
142     if (jsonObject.find(kKeyTypeTag) != std::string::npos) {
143         findValue(kKeyTypeTag, &value);
144         if (0 != value.compare(kSymmetricKeyValue))
145             return false;
146     }
147 
148     if (jsonObject.find(kKeyIdTag) != std::string::npos) {
149         findValue(kKeyIdTag, keyId);
150     }
151 
152     if (jsonObject.find(kKeyTag) != std::string::npos) {
153         findValue(kKeyTag, encodedKey);
154     }
155     return true;
156 }
157 
findValue(const std::string & key,std::string * value)158 void JsonWebKey::findValue(const std::string &key, std::string* value) {
159     value->clear();
160     const char* valueToken;
161     for (std::vector<std::string>::const_iterator nextToken = mTokens.begin();
162         nextToken != mTokens.end(); ++nextToken) {
163         if (0 == (*nextToken).compare(key)) {
164             if (nextToken + 1 == mTokens.end())
165                 break;
166             valueToken = (*(nextToken + 1)).c_str();
167             value->assign(valueToken);
168             nextToken++;
169             break;
170         }
171     }
172 }
173 
isJsonWebKeySet(const std::string & jsonObject) const174 bool JsonWebKey::isJsonWebKeySet(const std::string& jsonObject) const {
175     if (jsonObject.find(kKeysTag) == std::string::npos) {
176         ALOGE("JSON Web Key does not contain keys.");
177         return false;
178     }
179     return true;
180 }
181 
182 /*
183  * Parses a JSON objects string and initializes a vector of tokens.
184  *
185  * @return Returns false for errors, true for success.
186  */
parseJsonObject(const std::string & jsonObject,std::vector<std::string> * tokens)187 bool JsonWebKey::parseJsonObject(const std::string& jsonObject,
188         std::vector<std::string>* tokens) {
189     jsmn_parser parser;
190 
191     jsmn_init(&parser);
192     int numTokens = jsmn_parse(&parser,
193         jsonObject.c_str(), jsonObject.size(), nullptr, 0);
194     if (numTokens < 0) {
195         ALOGE("Parser returns error code=%d", numTokens);
196         return false;
197     }
198 
199     unsigned int jsmnTokensSize = numTokens * sizeof(jsmntok_t);
200     mJsmnTokens.clear();
201     mJsmnTokens.resize(jsmnTokensSize);
202 
203     jsmn_init(&parser);
204     int status = jsmn_parse(&parser, jsonObject.c_str(),
205         jsonObject.size(), mJsmnTokens.data(), numTokens);
206     if (status < 0) {
207         ALOGE("Parser returns error code=%d", status);
208         return false;
209     }
210 
211     tokens->clear();
212     std::string token;
213     const char *pjs;
214     for (int j = 0; j < numTokens; ++j) {
215         pjs = jsonObject.c_str() + mJsmnTokens[j].start;
216         if (mJsmnTokens[j].type == JSMN_STRING ||
217                 mJsmnTokens[j].type == JSMN_PRIMITIVE) {
218             token.assign(pjs, mJsmnTokens[j].end - mJsmnTokens[j].start);
219             tokens->push_back(token);
220         }
221     }
222     return true;
223 }
224 
225 /*
226  * Parses JSON Web Key Set string and initializes a vector of JSON objects.
227  *
228  * @return Returns false for errors, true for success.
229  */
parseJsonWebKeySet(const std::string & jsonWebKeySet,std::vector<std::string> * jsonObjects)230 bool JsonWebKey::parseJsonWebKeySet(const std::string& jsonWebKeySet,
231         std::vector<std::string>* jsonObjects) {
232     if (jsonWebKeySet.empty()) {
233         ALOGE("Empty JSON Web Key");
234         return false;
235     }
236 
237     // The jsmn parser only supports unicode encoding.
238     jsmn_parser parser;
239 
240     // Computes number of tokens. A token marks the type, offset in
241     // the original string.
242     jsmn_init(&parser);
243     int numTokens = jsmn_parse(&parser,
244             jsonWebKeySet.c_str(), jsonWebKeySet.size(), nullptr, 0);
245     if (numTokens < 0) {
246         ALOGE("Parser returns error code=%d", numTokens);
247         return false;
248     }
249 
250     unsigned int jsmnTokensSize = numTokens * sizeof(jsmntok_t);
251     mJsmnTokens.resize(jsmnTokensSize);
252 
253     jsmn_init(&parser);
254     int status = jsmn_parse(&parser, jsonWebKeySet.c_str(),
255             jsonWebKeySet.size(), mJsmnTokens.data(), numTokens);
256     if (status < 0) {
257         ALOGE("Parser returns error code=%d", status);
258         return false;
259     }
260 
261     std::string token;
262     const char *pjs;
263     for (int i = 0; i < numTokens; ++i) {
264         pjs = jsonWebKeySet.c_str() + mJsmnTokens[i].start;
265         if (mJsmnTokens[i].type == JSMN_OBJECT) {
266             token.assign(pjs, mJsmnTokens[i].end - mJsmnTokens[i].start);
267             jsonObjects->push_back(token);
268         }
269     }
270     return true;
271 }
272 
273 } // namespace clearkey
274 } // namespace V1_2
275 } // namespace drm
276 } // namespace hardware
277 } // namespace android
278 
279