1 /*
2 ** Copyright 2011, The Android Open Source Project
3 **
4 ** Licensed under the Apache License, Version 2.0 (the "License");
5 ** you may not use this file except in compliance with the License.
6 ** You may obtain a copy of the License at
7 **
8 ** http://www.apache.org/licenses/LICENSE-2.0
9 **
10 ** Unless required by applicable law or agreed to in writing, software
11 ** distributed under the License is distributed on an "AS IS" BASIS,
12 ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 ** See the License for the specific language governing permissions and
14 ** limitations under the License.
15 */
16
17 #include "nnCache.h"
18
19 #include <inttypes.h>
20 #include <sys/mman.h>
21 #include <sys/stat.h>
22 #include <unistd.h>
23
24 #include <thread>
25
26 #include <log/log.h>
27
// Cache file header: a 4-byte magic string followed by a 4-byte CRC of the
// flattened cache contents, for a total of cacheFileHeaderSize bytes.
static constexpr const char* cacheFileMagic = "nn$$";
static constexpr size_t cacheFileHeaderSize = 8;

// Number of seconds the deferred-save thread sleeps before writing newly
// inserted cache entries to disk.
static constexpr unsigned int deferredSaveDelay = 4;
34
35 // ----------------------------------------------------------------------------
36 namespace android {
37 // ----------------------------------------------------------------------------
38
39 //
40 // NNCache definition
41 //
NNCache()42 NNCache::NNCache()
43 : mInitialized(false),
44 mMaxKeySize(0),
45 mMaxValueSize(0),
46 mMaxTotalSize(0),
47 mPolicy(defaultPolicy()),
48 mSavePending(false) {}
49
~NNCache()50 NNCache::~NNCache() {}
51
// Storage for the static NNCache instance whose address NNCache::get() returns.
NNCache NNCache::sCache;
53
get()54 NNCache* NNCache::get() {
55 return &sCache;
56 }
57
initialize(size_t maxKeySize,size_t maxValueSize,size_t maxTotalSize,Policy policy)58 void NNCache::initialize(size_t maxKeySize, size_t maxValueSize, size_t maxTotalSize,
59 Policy policy) {
60 std::lock_guard<std::mutex> lock(mMutex);
61 mInitialized = true;
62 mMaxKeySize = maxKeySize;
63 mMaxValueSize = maxValueSize;
64 mMaxTotalSize = maxTotalSize;
65 mPolicy = policy;
66 }
67
terminate()68 void NNCache::terminate() {
69 std::lock_guard<std::mutex> lock(mMutex);
70 saveBlobCacheLocked();
71 mBlobCache = NULL;
72 mInitialized = false;
73 }
74
setBlob(const void * key,ssize_t keySize,const void * value,ssize_t valueSize)75 void NNCache::setBlob(const void* key, ssize_t keySize, const void* value, ssize_t valueSize) {
76 std::lock_guard<std::mutex> lock(mMutex);
77
78 if (keySize < 0 || valueSize < 0) {
79 ALOGW("nnCache::setBlob: negative sizes are not allowed");
80 return;
81 }
82
83 if (mInitialized) {
84 BlobCache* bc = getBlobCacheLocked();
85 bc->set(key, keySize, value, valueSize);
86
87 if (!mSavePending) {
88 mSavePending = true;
89 std::thread deferredSaveThread([this]() {
90 sleep(deferredSaveDelay);
91 std::lock_guard<std::mutex> lock(mMutex);
92 if (mInitialized) {
93 saveBlobCacheLocked();
94 }
95 mSavePending = false;
96 });
97 deferredSaveThread.detach();
98 }
99 }
100 }
101
getBlob(const void * key,ssize_t keySize,void * value,ssize_t valueSize)102 ssize_t NNCache::getBlob(const void* key, ssize_t keySize, void* value, ssize_t valueSize) {
103 std::lock_guard<std::mutex> lock(mMutex);
104
105 if (keySize < 0 || valueSize < 0) {
106 ALOGW("nnCache::getBlob: negative sizes are not allowed");
107 return 0;
108 }
109
110 if (mInitialized) {
111 BlobCache* bc = getBlobCacheLocked();
112 return bc->get(key, keySize, value, valueSize);
113 }
114 return 0;
115 }
116
getBlob(const void * key,ssize_t keySize,void ** value,std::function<void * (size_t)> alloc)117 ssize_t NNCache::getBlob(const void* key, ssize_t keySize, void** value,
118 std::function<void*(size_t)> alloc) {
119 std::lock_guard<std::mutex> lock(mMutex);
120
121 if (keySize < 0) {
122 ALOGW("nnCache::getBlob: negative sizes are not allowed");
123 return 0;
124 }
125
126 if (mInitialized) {
127 BlobCache* bc = getBlobCacheLocked();
128 return bc->get(key, keySize, value, alloc);
129 }
130 return 0;
131 }
132
setCacheFilename(const char * filename)133 void NNCache::setCacheFilename(const char* filename) {
134 std::lock_guard<std::mutex> lock(mMutex);
135 mFilename = filename;
136 }
137
getBlobCacheLocked()138 BlobCache* NNCache::getBlobCacheLocked() {
139 if (mBlobCache == nullptr) {
140 mBlobCache.reset(new BlobCache(mMaxKeySize, mMaxValueSize, mMaxTotalSize, mPolicy));
141 loadBlobCacheLocked();
142 }
143 return mBlobCache.get();
144 }
145
// Computes a bitwise CRC over the `len` bytes starting at `buf`, shifting
// right and folding in the polynomial constant 0x82F63B78 whenever the low bit
// is set. The register starts at zero and no final inversion is applied, so an
// empty buffer (and any run of zero bytes) yields 0.
static uint32_t crc32c(const uint8_t* buf, size_t len) {
    constexpr uint32_t kPoly = 0x82F63B78;
    uint32_t crc = 0;
    const uint8_t* const end = buf + len;
    for (const uint8_t* p = buf; p != end; ++p) {
        crc ^= *p;
        for (int bitsLeft = 8; bitsLeft > 0; --bitsLeft) {
            // All ones when the low bit is set, all zeros otherwise.
            const uint32_t mask = 0u - (crc & 1u);
            crc = (crc >> 1) ^ (kPoly & mask);
        }
    }
    return crc;
}
161
saveBlobCacheLocked()162 void NNCache::saveBlobCacheLocked() {
163 if (mFilename.length() > 0 && mBlobCache != NULL) {
164 size_t cacheSize = mBlobCache->getFlattenedSize();
165 size_t headerSize = cacheFileHeaderSize;
166 const char* fname = mFilename.c_str();
167
168 // Try to create the file with no permissions so we can write it
169 // without anyone trying to read it.
170 int fd = open(fname, O_CREAT | O_EXCL | O_RDWR, 0);
171 if (fd == -1) {
172 if (errno == EEXIST) {
173 // The file exists, delete it and try again.
174 if (unlink(fname) == -1) {
175 // No point in retrying if the unlink failed.
176 ALOGE("error unlinking cache file %s: %s (%d)", fname, strerror(errno), errno);
177 return;
178 }
179 // Retry now that we've unlinked the file.
180 fd = open(fname, O_CREAT | O_EXCL | O_RDWR, 0);
181 }
182 if (fd == -1) {
183 ALOGE("error creating cache file %s: %s (%d)", fname, strerror(errno), errno);
184 return;
185 }
186 }
187
188 size_t fileSize = headerSize + cacheSize;
189
190 uint8_t* buf = new uint8_t[fileSize];
191 if (!buf) {
192 ALOGE("error allocating buffer for cache contents: %s (%d)", strerror(errno), errno);
193 close(fd);
194 unlink(fname);
195 return;
196 }
197
198 int err = mBlobCache->flatten(buf + headerSize, cacheSize);
199 if (err < 0) {
200 ALOGE("error writing cache contents: %s (%d)", strerror(-err), -err);
201 delete[] buf;
202 close(fd);
203 unlink(fname);
204 return;
205 }
206
207 // Write the file magic and CRC
208 memcpy(buf, cacheFileMagic, 4);
209 uint32_t* crc = reinterpret_cast<uint32_t*>(buf + 4);
210 *crc = crc32c(buf + headerSize, cacheSize);
211
212 if (write(fd, buf, fileSize) == -1) {
213 ALOGE("error writing cache file: %s (%d)", strerror(errno), errno);
214 delete[] buf;
215 close(fd);
216 unlink(fname);
217 return;
218 }
219
220 delete[] buf;
221 fchmod(fd, S_IRUSR);
222 close(fd);
223 }
224 }
225
loadBlobCacheLocked()226 void NNCache::loadBlobCacheLocked() {
227 if (mFilename.length() > 0) {
228 size_t headerSize = cacheFileHeaderSize;
229
230 int fd = open(mFilename.c_str(), O_RDONLY, 0);
231 if (fd == -1) {
232 if (errno != ENOENT) {
233 ALOGE("error opening cache file %s: %s (%d)", mFilename.c_str(), strerror(errno),
234 errno);
235 }
236 return;
237 }
238
239 struct stat statBuf;
240 if (fstat(fd, &statBuf) == -1) {
241 ALOGE("error stat'ing cache file: %s (%d)", strerror(errno), errno);
242 close(fd);
243 return;
244 }
245
246 // Sanity check the size before trying to mmap it.
247 size_t fileSize = statBuf.st_size;
248 if (fileSize > mMaxTotalSize * 2) {
249 ALOGE("cache file is too large: %#" PRIx64, static_cast<off64_t>(statBuf.st_size));
250 close(fd);
251 return;
252 }
253
254 uint8_t* buf =
255 reinterpret_cast<uint8_t*>(mmap(NULL, fileSize, PROT_READ, MAP_PRIVATE, fd, 0));
256 if (buf == MAP_FAILED) {
257 ALOGE("error mmaping cache file: %s (%d)", strerror(errno), errno);
258 close(fd);
259 return;
260 }
261
262 // Check the file magic and CRC
263 size_t cacheSize = fileSize - headerSize;
264 if (memcmp(buf, cacheFileMagic, 4) != 0) {
265 ALOGE("cache file has bad mojo");
266 close(fd);
267 return;
268 }
269 uint32_t* crc = reinterpret_cast<uint32_t*>(buf + 4);
270 if (crc32c(buf + headerSize, cacheSize) != *crc) {
271 ALOGE("cache file failed CRC check");
272 close(fd);
273 return;
274 }
275
276 int err = mBlobCache->unflatten(buf + headerSize, cacheSize);
277 if (err < 0) {
278 ALOGE("error reading cache contents: %s (%d)", strerror(-err), -err);
279 munmap(buf, fileSize);
280 close(fd);
281 return;
282 }
283
284 munmap(buf, fileSize);
285 close(fd);
286 }
287 }
288
289 // ----------------------------------------------------------------------------
290 }; // namespace android
291 // ----------------------------------------------------------------------------
292