1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #define LOG_TAG "clearkey-JsonWebKey"
17 
18 #include <utils/Log.h>
19 
20 #include "JsonWebKey.h"
21 
22 #include "Base64.h"
23 
namespace {
// Padding character of standard base64; the EME spec forbids it in the
// encoded key ids and keys handled here.
const std::string kBase64Padding = "=";

// JSON Web Key Set / JSON Web Key member names.
const std::string kKeysTag = "keys";
const std::string kKeyTypeTag = "kty";
const std::string kKeyTag = "k";
const std::string kKeyIdTag = "kid";

// The only "kty" value accepted by findKey() (symmetric key).
const std::string kSymmetricKeyValue = "oct";

// Session type tag and values (not referenced in this file).
const std::string kMediaSessionType = "type";
const std::string kPersistentLicenseSession = "persistent-license";
const std::string kTemporaryLicenseSession = "temporary";
}  // namespace
35 
36 namespace clearkeydrm {
37 
JsonWebKey()38 JsonWebKey::JsonWebKey() {}
39 
~JsonWebKey()40 JsonWebKey::~JsonWebKey() {}
41 
42 /*
43  * Parses a JSON Web Key Set string, initializes a KeyMap with key id:key
44  * pairs from the JSON Web Key Set. Both key ids and keys are base64url
45  * encoded. The KeyMap contains base64url decoded key id:key pairs.
46  *
47  * @return Returns false for errors, true for success.
48  */
extractKeysFromJsonWebKeySet(const std::string & jsonWebKeySet,KeyMap * keys)49 bool JsonWebKey::extractKeysFromJsonWebKeySet(const std::string& jsonWebKeySet, KeyMap* keys) {
50     keys->clear();
51 
52     if (!parseJsonWebKeySet(jsonWebKeySet, &mJsonObjects)) {
53         return false;
54     }
55 
56     // mJsonObjects[0] contains the entire JSON Web Key Set, including
57     // all the base64 encoded keys. Each key is also stored separately as
58     // a JSON object in mJsonObjects[1..n] where n is the total
59     // number of keys in the set.
60     if (mJsonObjects.size() == 0 || !isJsonWebKeySet(mJsonObjects[0])) {
61         return false;
62     }
63 
64     std::string encodedKey, encodedKeyId;
65     std::vector<uint8_t> decodedKey, decodedKeyId;
66 
67     // mJsonObjects[1] contains the first JSON Web Key in the set
68     for (size_t i = 1; i < mJsonObjects.size(); ++i) {
69         encodedKeyId.clear();
70         encodedKey.clear();
71 
72         if (!parseJsonObject(mJsonObjects[i], &mTokens)) return false;
73 
74         if (findKey(mJsonObjects[i], &encodedKeyId, &encodedKey)) {
75             if (encodedKeyId.empty() || encodedKey.empty()) {
76                 ALOGE("Must have both key id and key in the JsonWebKey set.");
77                 continue;
78             }
79 
80             if (!decodeBase64String(encodedKeyId, &decodedKeyId)) {
81                 ALOGE("Failed to decode key id(%s)", encodedKeyId.c_str());
82                 continue;
83             }
84 
85             if (!decodeBase64String(encodedKey, &decodedKey)) {
86                 ALOGE("Failed to decode key(%s)", encodedKey.c_str());
87                 continue;
88             }
89 
90             keys->insert(std::pair<std::vector<uint8_t>, std::vector<uint8_t>>(decodedKeyId,
91                                                                                decodedKey));
92         }
93     }
94     return true;
95 }
96 
decodeBase64String(const std::string & encodedText,std::vector<uint8_t> * decodedText)97 bool JsonWebKey::decodeBase64String(const std::string& encodedText,
98                                     std::vector<uint8_t>* decodedText) {
99     decodedText->clear();
100 
101     // encodedText should not contain padding characters as per EME spec.
102     if (encodedText.find(kBase64Padding) != std::string::npos) {
103         return false;
104     }
105 
106     // Since decodeBase64() requires padding characters,
107     // add them so length of encodedText is exactly a multiple of 4.
108     int remainder = encodedText.length() % 4;
109     std::string paddedText(encodedText);
110     if (remainder > 0) {
111         for (int i = 0; i < 4 - remainder; ++i) {
112             paddedText.append(kBase64Padding);
113         }
114     }
115 
116     ::android::sp<Buffer> buffer = decodeBase64(paddedText);
117     if (buffer == nullptr) {
118         ALOGE("Malformed base64 encoded content found.");
119         return false;
120     }
121 
122     decodedText->insert(decodedText->end(), buffer->base(), buffer->base() + buffer->size());
123     return true;
124 }
125 
findKey(const std::string & jsonObject,std::string * keyId,std::string * encodedKey)126 bool JsonWebKey::findKey(const std::string& jsonObject, std::string* keyId,
127                          std::string* encodedKey) {
128     std::string key, value;
129 
130     // Only allow symmetric key, i.e. "kty":"oct" pair.
131     if (jsonObject.find(kKeyTypeTag) != std::string::npos) {
132         findValue(kKeyTypeTag, &value);
133         if (0 != value.compare(kSymmetricKeyValue)) return false;
134     }
135 
136     if (jsonObject.find(kKeyIdTag) != std::string::npos) {
137         findValue(kKeyIdTag, keyId);
138     }
139 
140     if (jsonObject.find(kKeyTag) != std::string::npos) {
141         findValue(kKeyTag, encodedKey);
142     }
143     return true;
144 }
145 
findValue(const std::string & key,std::string * value)146 void JsonWebKey::findValue(const std::string& key, std::string* value) {
147     value->clear();
148     const char* valueToken;
149     for (std::vector<std::string>::const_iterator nextToken = mTokens.begin();
150          nextToken != mTokens.end(); ++nextToken) {
151         if (0 == (*nextToken).compare(key)) {
152             if (nextToken + 1 == mTokens.end()) break;
153             valueToken = (*(nextToken + 1)).c_str();
154             value->assign(valueToken);
155             nextToken++;
156             break;
157         }
158     }
159 }
160 
isJsonWebKeySet(const std::string & jsonObject) const161 bool JsonWebKey::isJsonWebKeySet(const std::string& jsonObject) const {
162     if (jsonObject.find(kKeysTag) == std::string::npos) {
163         ALOGE("JSON Web Key does not contain keys.");
164         return false;
165     }
166     return true;
167 }
168 
169 /*
170  * Parses a JSON objects string and initializes a vector of tokens.
171  *
172  * @return Returns false for errors, true for success.
173  */
parseJsonObject(const std::string & jsonObject,std::vector<std::string> * tokens)174 bool JsonWebKey::parseJsonObject(const std::string& jsonObject, std::vector<std::string>* tokens) {
175     jsmn_parser parser;
176 
177     jsmn_init(&parser);
178     int numTokens = jsmn_parse(&parser, jsonObject.c_str(), jsonObject.size(), nullptr, 0);
179     if (numTokens < 0) {
180         ALOGE("Parser returns error code=%d", numTokens);
181         return false;
182     }
183 
184     unsigned int jsmnTokensSize = numTokens * sizeof(jsmntok_t);
185     mJsmnTokens.clear();
186     mJsmnTokens.resize(jsmnTokensSize);
187 
188     jsmn_init(&parser);
189     int status = jsmn_parse(&parser, jsonObject.c_str(), jsonObject.size(), mJsmnTokens.data(),
190                             numTokens);
191     if (status < 0) {
192         ALOGE("Parser returns error code=%d", status);
193         return false;
194     }
195 
196     tokens->clear();
197     std::string token;
198     const char* pjs;
199     for (int j = 0; j < numTokens; ++j) {
200         pjs = jsonObject.c_str() + mJsmnTokens[j].start;
201         if (mJsmnTokens[j].type == JSMN_STRING || mJsmnTokens[j].type == JSMN_PRIMITIVE) {
202             token.assign(pjs, mJsmnTokens[j].end - mJsmnTokens[j].start);
203             tokens->push_back(token);
204         }
205     }
206     return true;
207 }
208 
209 /*
210  * Parses JSON Web Key Set string and initializes a vector of JSON objects.
211  *
212  * @return Returns false for errors, true for success.
213  */
parseJsonWebKeySet(const std::string & jsonWebKeySet,std::vector<std::string> * jsonObjects)214 bool JsonWebKey::parseJsonWebKeySet(const std::string& jsonWebKeySet,
215                                     std::vector<std::string>* jsonObjects) {
216     if (jsonWebKeySet.empty()) {
217         ALOGE("Empty JSON Web Key");
218         return false;
219     }
220 
221     // The jsmn parser only supports unicode encoding.
222     jsmn_parser parser;
223 
224     // Computes number of tokens. A token marks the type, offset in
225     // the original string.
226     jsmn_init(&parser);
227     int numTokens = jsmn_parse(&parser, jsonWebKeySet.c_str(), jsonWebKeySet.size(), nullptr, 0);
228     if (numTokens < 0) {
229         ALOGE("Parser returns error code=%d", numTokens);
230         return false;
231     }
232 
233     unsigned int jsmnTokensSize = numTokens * sizeof(jsmntok_t);
234     mJsmnTokens.resize(jsmnTokensSize);
235 
236     jsmn_init(&parser);
237     int status = jsmn_parse(&parser, jsonWebKeySet.c_str(), jsonWebKeySet.size(),
238                             mJsmnTokens.data(), numTokens);
239     if (status < 0) {
240         ALOGE("Parser returns error code=%d", status);
241         return false;
242     }
243 
244     std::string token;
245     const char* pjs;
246     for (int i = 0; i < numTokens; ++i) {
247         pjs = jsonWebKeySet.c_str() + mJsmnTokens[i].start;
248         if (mJsmnTokens[i].type == JSMN_OBJECT) {
249             token.assign(pjs, mJsmnTokens[i].end - mJsmnTokens[i].start);
250             jsonObjects->push_back(token);
251         }
252     }
253     return true;
254 }
255 
256 }  // namespace clearkeydrm
257