1 /*
2 ** Copyright 2011, The Android Open Source Project
3 **
4 ** Licensed under the Apache License, Version 2.0 (the "License");
5 ** you may not use this file except in compliance with the License.
6 ** You may obtain a copy of the License at
7 **
8 ** http://www.apache.org/licenses/LICENSE-2.0
9 **
10 ** Unless required by applicable law or agreed to in writing, software
11 ** distributed under the License is distributed on an "AS IS" BASIS,
12 ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 ** See the License for the specific language governing permissions and
14 ** limitations under the License.
15 */
16
17 #include "nnCache.h"
18
#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include <thread>
#include <vector>

#include <log/log.h>
27
// Cache file header: the 4-byte magic below followed by a 4-byte CRC32C of
// the serialized cache payload.
static constexpr char cacheFileMagic[] = "nn$$";
static constexpr size_t cacheFileHeaderSize = 8;
static_assert(sizeof(cacheFileMagic) - 1 + sizeof(uint32_t) == cacheFileHeaderSize,
              "cache file header must hold exactly the magic and a uint32_t CRC");

// The time in seconds to wait before saving newly inserted cache entries.
static constexpr unsigned int deferredSaveDelay = 4;
34
35 // ----------------------------------------------------------------------------
36 namespace android {
37 // ----------------------------------------------------------------------------
38
39 //
40 // NNCache definition
41 //
NNCache()42 NNCache::NNCache() :
43 mInitialized(false),
44 mMaxKeySize(0), mMaxValueSize(0), mMaxTotalSize(0),
45 mPolicy(defaultPolicy()),
46 mSavePending(false) {
47 }
48
~NNCache()49 NNCache::~NNCache() {
50 }
51
52 NNCache NNCache::sCache;
53
get()54 NNCache* NNCache::get() {
55 return &sCache;
56 }
57
initialize(size_t maxKeySize,size_t maxValueSize,size_t maxTotalSize,Policy policy)58 void NNCache::initialize(size_t maxKeySize, size_t maxValueSize, size_t maxTotalSize,
59 Policy policy) {
60 std::lock_guard<std::mutex> lock(mMutex);
61 mInitialized = true;
62 mMaxKeySize = maxKeySize;
63 mMaxValueSize = maxValueSize;
64 mMaxTotalSize = maxTotalSize;
65 mPolicy = policy;
66 }
67
terminate()68 void NNCache::terminate() {
69 std::lock_guard<std::mutex> lock(mMutex);
70 saveBlobCacheLocked();
71 mBlobCache = NULL;
72 mInitialized = false;
73 }
74
setBlob(const void * key,ssize_t keySize,const void * value,ssize_t valueSize)75 void NNCache::setBlob(const void* key, ssize_t keySize,
76 const void* value, ssize_t valueSize) {
77 std::lock_guard<std::mutex> lock(mMutex);
78
79 if (keySize < 0 || valueSize < 0) {
80 ALOGW("nnCache::setBlob: negative sizes are not allowed");
81 return;
82 }
83
84 if (mInitialized) {
85 BlobCache* bc = getBlobCacheLocked();
86 bc->set(key, keySize, value, valueSize);
87
88 if (!mSavePending) {
89 mSavePending = true;
90 std::thread deferredSaveThread([this]() {
91 sleep(deferredSaveDelay);
92 std::lock_guard<std::mutex> lock(mMutex);
93 if (mInitialized) {
94 saveBlobCacheLocked();
95 }
96 mSavePending = false;
97 });
98 deferredSaveThread.detach();
99 }
100 }
101 }
102
getBlob(const void * key,ssize_t keySize,void * value,ssize_t valueSize)103 ssize_t NNCache::getBlob(const void* key, ssize_t keySize,
104 void* value, ssize_t valueSize) {
105 std::lock_guard<std::mutex> lock(mMutex);
106
107 if (keySize < 0 || valueSize < 0) {
108 ALOGW("nnCache::getBlob: negative sizes are not allowed");
109 return 0;
110 }
111
112 if (mInitialized) {
113 BlobCache* bc = getBlobCacheLocked();
114 return bc->get(key, keySize, value, valueSize);
115 }
116 return 0;
117 }
118
getBlob(const void * key,ssize_t keySize,void ** value,std::function<void * (size_t)> alloc)119 ssize_t NNCache::getBlob(const void* key, ssize_t keySize,
120 void** value, std::function<void*(size_t)> alloc) {
121 std::lock_guard<std::mutex> lock(mMutex);
122
123 if (keySize < 0) {
124 ALOGW("nnCache::getBlob: negative sizes are not allowed");
125 return 0;
126 }
127
128 if (mInitialized) {
129 BlobCache* bc = getBlobCacheLocked();
130 return bc->get(key, keySize, value, alloc);
131 }
132 return 0;
133 }
134
setCacheFilename(const char * filename)135 void NNCache::setCacheFilename(const char* filename) {
136 std::lock_guard<std::mutex> lock(mMutex);
137 mFilename = filename;
138 }
139
getBlobCacheLocked()140 BlobCache* NNCache::getBlobCacheLocked() {
141 if (mBlobCache == nullptr) {
142 mBlobCache.reset(new BlobCache(mMaxKeySize, mMaxValueSize, mMaxTotalSize, mPolicy));
143 loadBlobCacheLocked();
144 }
145 return mBlobCache.get();
146 }
147
// Bitwise CRC-32C (Castagnoli, reflected polynomial 0x82F63B78) over `len`
// bytes at `buf`. The register starts at 0 and no final XOR is applied, so
// results differ from the common CRC-32C convention (init/xorout 0xFFFFFFFF);
// the same routine must therefore be used for writing and verifying files.
static uint32_t crc32c(const uint8_t* buf, size_t len) {
    const uint32_t kReflectedPoly = 0x82F63B78;
    uint32_t crc = 0;
    for (const uint8_t *p = buf, *end = buf + len; p != end; ++p) {
        crc ^= *p;
        for (int bit = 8; bit > 0; --bit) {
            // All-ones when the low bit is set, zero otherwise.
            const uint32_t mask = 0u - (crc & 1u);
            crc = (crc >> 1) ^ (kReflectedPoly & mask);
        }
    }
    return crc;
}
163
saveBlobCacheLocked()164 void NNCache::saveBlobCacheLocked() {
165 if (mFilename.length() > 0 && mBlobCache != NULL) {
166 size_t cacheSize = mBlobCache->getFlattenedSize();
167 size_t headerSize = cacheFileHeaderSize;
168 const char* fname = mFilename.c_str();
169
170 // Try to create the file with no permissions so we can write it
171 // without anyone trying to read it.
172 int fd = open(fname, O_CREAT | O_EXCL | O_RDWR, 0);
173 if (fd == -1) {
174 if (errno == EEXIST) {
175 // The file exists, delete it and try again.
176 if (unlink(fname) == -1) {
177 // No point in retrying if the unlink failed.
178 ALOGE("error unlinking cache file %s: %s (%d)", fname,
179 strerror(errno), errno);
180 return;
181 }
182 // Retry now that we've unlinked the file.
183 fd = open(fname, O_CREAT | O_EXCL | O_RDWR, 0);
184 }
185 if (fd == -1) {
186 ALOGE("error creating cache file %s: %s (%d)", fname,
187 strerror(errno), errno);
188 return;
189 }
190 }
191
192 size_t fileSize = headerSize + cacheSize;
193
194 uint8_t* buf = new uint8_t [fileSize];
195 if (!buf) {
196 ALOGE("error allocating buffer for cache contents: %s (%d)",
197 strerror(errno), errno);
198 close(fd);
199 unlink(fname);
200 return;
201 }
202
203 int err = mBlobCache->flatten(buf + headerSize, cacheSize);
204 if (err < 0) {
205 ALOGE("error writing cache contents: %s (%d)", strerror(-err),
206 -err);
207 delete [] buf;
208 close(fd);
209 unlink(fname);
210 return;
211 }
212
213 // Write the file magic and CRC
214 memcpy(buf, cacheFileMagic, 4);
215 uint32_t* crc = reinterpret_cast<uint32_t*>(buf + 4);
216 *crc = crc32c(buf + headerSize, cacheSize);
217
218 if (write(fd, buf, fileSize) == -1) {
219 ALOGE("error writing cache file: %s (%d)", strerror(errno),
220 errno);
221 delete [] buf;
222 close(fd);
223 unlink(fname);
224 return;
225 }
226
227 delete [] buf;
228 fchmod(fd, S_IRUSR);
229 close(fd);
230 }
231 }
232
loadBlobCacheLocked()233 void NNCache::loadBlobCacheLocked() {
234 if (mFilename.length() > 0) {
235 size_t headerSize = cacheFileHeaderSize;
236
237 int fd = open(mFilename.c_str(), O_RDONLY, 0);
238 if (fd == -1) {
239 if (errno != ENOENT) {
240 ALOGE("error opening cache file %s: %s (%d)", mFilename.c_str(),
241 strerror(errno), errno);
242 }
243 return;
244 }
245
246 struct stat statBuf;
247 if (fstat(fd, &statBuf) == -1) {
248 ALOGE("error stat'ing cache file: %s (%d)", strerror(errno), errno);
249 close(fd);
250 return;
251 }
252
253 // Sanity check the size before trying to mmap it.
254 size_t fileSize = statBuf.st_size;
255 if (fileSize > mMaxTotalSize * 2) {
256 ALOGE("cache file is too large: %#" PRIx64,
257 static_cast<off64_t>(statBuf.st_size));
258 close(fd);
259 return;
260 }
261
262 uint8_t* buf = reinterpret_cast<uint8_t*>(mmap(NULL, fileSize,
263 PROT_READ, MAP_PRIVATE, fd, 0));
264 if (buf == MAP_FAILED) {
265 ALOGE("error mmaping cache file: %s (%d)", strerror(errno),
266 errno);
267 close(fd);
268 return;
269 }
270
271 // Check the file magic and CRC
272 size_t cacheSize = fileSize - headerSize;
273 if (memcmp(buf, cacheFileMagic, 4) != 0) {
274 ALOGE("cache file has bad mojo");
275 close(fd);
276 return;
277 }
278 uint32_t* crc = reinterpret_cast<uint32_t*>(buf + 4);
279 if (crc32c(buf + headerSize, cacheSize) != *crc) {
280 ALOGE("cache file failed CRC check");
281 close(fd);
282 return;
283 }
284
285 int err = mBlobCache->unflatten(buf + headerSize, cacheSize);
286 if (err < 0) {
287 ALOGE("error reading cache contents: %s (%d)", strerror(-err),
288 -err);
289 munmap(buf, fileSize);
290 close(fd);
291 return;
292 }
293
294 munmap(buf, fileSize);
295 close(fd);
296 }
297 }
298
299 // ----------------------------------------------------------------------------
300 }; // namespace android
301 // ----------------------------------------------------------------------------
302