1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <utils/Log.h>
17 
18 #include <sys/stat.h>
19 #include <string>
20 
21 #include "DeviceFiles.h"
22 #include "protos/DeviceFiles.pb.h"
23 
24 #include <openssl/sha.h>
25 
26 // Protobuf generated classes.
27 using clearkeydrm::HashedFile;
28 using clearkeydrm::License;
29 using clearkeydrm::License_LicenseState_ACTIVE;
30 using clearkeydrm::License_LicenseState_RELEASING;
31 using clearkeydrm::OfflineFile;
32 
33 namespace {
34 const char kLicenseFileNameExt[] = ".lic";
35 
Hash(const std::string & data,std::string * hash)36 bool Hash(const std::string& data, std::string* hash) {
37     if (!hash) return false;
38 
39     hash->resize(SHA256_DIGEST_LENGTH);
40 
41     const unsigned char* input = reinterpret_cast<const unsigned char*>(data.data());
42     unsigned char* output = reinterpret_cast<unsigned char*>(&(*hash)[0]);
43     SHA256(input, data.size(), output);
44     return true;
45 }
46 
47 }  // namespace
48 
49 namespace clearkeydrm {
50 
StoreLicense(const std::string & keySetId,LicenseState state,const std::string & licenseResponse)51 bool DeviceFiles::StoreLicense(const std::string& keySetId, LicenseState state,
52                                const std::string& licenseResponse) {
53     OfflineFile file;
54     file.set_type(OfflineFile::LICENSE);
55     file.set_version(OfflineFile::VERSION_1);
56 
57     License* license = file.mutable_license();
58     switch (state) {
59         case kLicenseStateActive:
60             license->set_state(License_LicenseState_ACTIVE);
61             license->set_license(licenseResponse);
62             break;
63         case kLicenseStateReleasing:
64             license->set_state(License_LicenseState_RELEASING);
65             license->set_license(licenseResponse);
66             break;
67         default:
68             ALOGW("StoreLicense: Unknown license state: %u", state);
69             return false;
70     }
71 
72     std::string serializedFile;
73     file.SerializeToString(&serializedFile);
74 
75     return StoreFileWithHash(keySetId + kLicenseFileNameExt, serializedFile);
76 }
77 
StoreFileWithHash(const std::string & fileName,const std::string & serializedFile)78 bool DeviceFiles::StoreFileWithHash(const std::string& fileName,
79                                     const std::string& serializedFile) {
80     std::string hash;
81     if (!Hash(serializedFile, &hash)) {
82         ALOGE("StoreFileWithHash: Failed to compute hash");
83         return false;
84     }
85 
86     HashedFile hashFile;
87     hashFile.set_file(serializedFile);
88     hashFile.set_hash(hash);
89 
90     std::string serializedHashFile;
91     hashFile.SerializeToString(&serializedHashFile);
92 
93     return StoreFileRaw(fileName, serializedHashFile);
94 }
95 
StoreFileRaw(const std::string & fileName,const std::string & serializedHashFile)96 bool DeviceFiles::StoreFileRaw(const std::string& fileName, const std::string& serializedHashFile) {
97     MemoryFileSystem::MemoryFile memFile;
98     memFile.setFileName(fileName);
99     memFile.setContent(serializedHashFile);
100     memFile.setFileSize(serializedHashFile.size());
101     size_t len = mFileHandle.Write(fileName, memFile);
102 
103     if (len != static_cast<size_t>(serializedHashFile.size())) {
104         ALOGE("StoreFileRaw: Failed to write %s", fileName.c_str());
105         ALOGD("StoreFileRaw: expected=%zd, actual=%zu", serializedHashFile.size(), len);
106         return false;
107     }
108 
109     ALOGD("StoreFileRaw: wrote %zu bytes to %s", serializedHashFile.size(), fileName.c_str());
110     return true;
111 }
112 
RetrieveLicense(const std::string & keySetId,LicenseState * state,std::string * offlineLicense)113 bool DeviceFiles::RetrieveLicense(const std::string& keySetId, LicenseState* state,
114                                   std::string* offlineLicense) {
115     OfflineFile file;
116     if (!RetrieveHashedFile(keySetId + kLicenseFileNameExt, &file)) {
117         return false;
118     }
119 
120     if (file.type() != OfflineFile::LICENSE) {
121         ALOGE("RetrieveLicense: Invalid file type");
122         return false;
123     }
124 
125     if (file.version() != OfflineFile::VERSION_1) {
126         ALOGE("RetrieveLicense: Invalid file version");
127         return false;
128     }
129 
130     if (!file.has_license()) {
131         ALOGE("RetrieveLicense: License not present");
132         return false;
133     }
134 
135     License license = file.license();
136     switch (license.state()) {
137         case License_LicenseState_ACTIVE:
138             *state = kLicenseStateActive;
139             break;
140         case License_LicenseState_RELEASING:
141             *state = kLicenseStateReleasing;
142             break;
143         default:
144             ALOGW("RetrieveLicense: Unrecognized license state: %u", kLicenseStateUnknown);
145             *state = kLicenseStateUnknown;
146             break;
147     }
148     *offlineLicense = license.license();
149     return true;
150 }
151 
DeleteLicense(const std::string & keySetId)152 bool DeviceFiles::DeleteLicense(const std::string& keySetId) {
153     return mFileHandle.RemoveFile(keySetId + kLicenseFileNameExt);
154 }
155 
DeleteAllLicenses()156 bool DeviceFiles::DeleteAllLicenses() {
157     return mFileHandle.RemoveAllFiles();
158 }
159 
LicenseExists(const std::string & keySetId)160 bool DeviceFiles::LicenseExists(const std::string& keySetId) {
161     return mFileHandle.FileExists(keySetId + kLicenseFileNameExt);
162 }
163 
ListLicenses() const164 std::vector<std::string> DeviceFiles::ListLicenses() const {
165     std::vector<std::string> licenses = mFileHandle.ListFiles();
166     for (size_t i = 0; i < licenses.size(); i++) {
167         std::string& license = licenses[i];
168         license = license.substr(0, license.size() - strlen(kLicenseFileNameExt));
169     }
170     return licenses;
171 }
172 
RetrieveHashedFile(const std::string & fileName,OfflineFile * deSerializedFile)173 bool DeviceFiles::RetrieveHashedFile(const std::string& fileName, OfflineFile* deSerializedFile) {
174     if (!deSerializedFile) {
175         ALOGE("RetrieveHashedFile: invalid file parameter");
176         return false;
177     }
178 
179     if (!FileExists(fileName)) {
180         ALOGE("RetrieveHashedFile: %s does not exist", fileName.c_str());
181         return false;
182     }
183 
184     ssize_t bytes = GetFileSize(fileName);
185     if (bytes <= 0) {
186         ALOGE("RetrieveHashedFile: invalid file size: %s", fileName.c_str());
187         // Remove the corrupted file so the caller will not get the same error
188         // when trying to access the file repeatedly, causing the system to stall.
189         RemoveFile(fileName);
190         return false;
191     }
192 
193     std::string serializedHashFile;
194     serializedHashFile.resize(bytes);
195     bytes = mFileHandle.Read(fileName, &serializedHashFile);
196 
197     if (bytes != static_cast<ssize_t>(serializedHashFile.size())) {
198         ALOGE("RetrieveHashedFile: Failed to read from %s", fileName.c_str());
199         ALOGV("RetrieveHashedFile: expected: %zd, actual: %zd", serializedHashFile.size(), bytes);
200         // Remove the corrupted file so the caller will not get the same error
201         // when trying to access the file repeatedly, causing the system to stall.
202         RemoveFile(fileName);
203         return false;
204     }
205 
206     ALOGV("RetrieveHashedFile: read %zd from %s", bytes, fileName.c_str());
207 
208     HashedFile hashFile;
209     if (!hashFile.ParseFromString(serializedHashFile)) {
210         ALOGE("RetrieveHashedFile: Unable to parse hash file");
211         // Remove corrupt file.
212         RemoveFile(fileName);
213         return false;
214     }
215 
216     std::string hash;
217     if (!Hash(hashFile.file(), &hash)) {
218         ALOGE("RetrieveHashedFile: Hash computation failed");
219         return false;
220     }
221 
222     if (hash != hashFile.hash()) {
223         ALOGE("RetrieveHashedFile: Hash mismatch");
224         // Remove corrupt file.
225         RemoveFile(fileName);
226         return false;
227     }
228 
229     if (!deSerializedFile->ParseFromString(hashFile.file())) {
230         ALOGE("RetrieveHashedFile: Unable to parse file");
231         // Remove corrupt file.
232         RemoveFile(fileName);
233         return false;
234     }
235 
236     return true;
237 }
238 
FileExists(const std::string & fileName) const239 bool DeviceFiles::FileExists(const std::string& fileName) const {
240     return mFileHandle.FileExists(fileName);
241 }
242 
RemoveFile(const std::string & fileName)243 bool DeviceFiles::RemoveFile(const std::string& fileName) {
244     return mFileHandle.RemoveFile(fileName);
245 }
246 
GetFileSize(const std::string & fileName) const247 ssize_t DeviceFiles::GetFileSize(const std::string& fileName) const {
248     return mFileHandle.GetFileSize(fileName);
249 }
250 
251 }  // namespace clearkeydrm
252