1 /*
2 ** Copyright 2011, The Android Open Source Project
3 **
4 ** Licensed under the Apache License, Version 2.0 (the "License");
5 ** you may not use this file except in compliance with the License.
6 ** You may obtain a copy of the License at
7 **
8 ** http://www.apache.org/licenses/LICENSE-2.0
9 **
10 ** Unless required by applicable law or agreed to in writing, software
11 ** distributed under the License is distributed on an "AS IS" BASIS,
12 ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 ** See the License for the specific language governing permissions and
14 ** limitations under the License.
15 */
16
#include "nnCache.h"

#include <fcntl.h>
#include <inttypes.h>
#include <log/log.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include <functional>
#include <memory>
#include <new>
#include <thread>
28
// Cache file header: a 4-byte magic value followed by a 4-byte CRC of the
// flattened cache contents. A constexpr array (rather than a mutable
// `const char*` global) keeps the magic from being re-pointed at runtime.
static constexpr char cacheFileMagic[] = "nn$$";
static constexpr size_t cacheFileHeaderSize = 8;
static_assert(sizeof(cacheFileMagic) - 1 + sizeof(uint32_t) == cacheFileHeaderSize,
              "cache file header must hold the magic followed by a uint32_t CRC");

// The time in seconds to wait before saving newly inserted cache entries.
static constexpr unsigned int deferredSaveDelay = 4;
35
36 // ----------------------------------------------------------------------------
37 namespace android {
38 // ----------------------------------------------------------------------------
39
40 //
41 // NNCache definition
42 //
NNCache()43 NNCache::NNCache()
44 : mInitialized(false),
45 mMaxKeySize(0),
46 mMaxValueSize(0),
47 mMaxTotalSize(0),
48 mPolicy(defaultPolicy()),
49 mSavePending(false) {}
50
~NNCache()51 NNCache::~NNCache() {}
52
53 NNCache NNCache::sCache;
54
get()55 NNCache* NNCache::get() {
56 return &sCache;
57 }
58
initialize(size_t maxKeySize,size_t maxValueSize,size_t maxTotalSize,Policy policy)59 void NNCache::initialize(size_t maxKeySize, size_t maxValueSize, size_t maxTotalSize,
60 Policy policy) {
61 std::lock_guard<std::mutex> lock(mMutex);
62 mInitialized = true;
63 mMaxKeySize = maxKeySize;
64 mMaxValueSize = maxValueSize;
65 mMaxTotalSize = maxTotalSize;
66 mPolicy = policy;
67 }
68
terminate()69 void NNCache::terminate() {
70 std::lock_guard<std::mutex> lock(mMutex);
71 saveBlobCacheLocked();
72 mBlobCache = NULL;
73 mInitialized = false;
74 }
75
setBlob(const void * key,ssize_t keySize,const void * value,ssize_t valueSize)76 void NNCache::setBlob(const void* key, ssize_t keySize, const void* value, ssize_t valueSize) {
77 std::lock_guard<std::mutex> lock(mMutex);
78
79 if (keySize < 0 || valueSize < 0) {
80 ALOGW("nnCache::setBlob: negative sizes are not allowed");
81 return;
82 }
83
84 if (mInitialized) {
85 BlobCache* bc = getBlobCacheLocked();
86 bc->set(key, keySize, value, valueSize);
87
88 if (!mSavePending) {
89 mSavePending = true;
90 std::thread deferredSaveThread([this]() {
91 sleep(deferredSaveDelay);
92 std::lock_guard<std::mutex> lock(mMutex);
93 if (mInitialized) {
94 saveBlobCacheLocked();
95 }
96 mSavePending = false;
97 });
98 deferredSaveThread.detach();
99 }
100 }
101 }
102
getBlob(const void * key,ssize_t keySize,void * value,ssize_t valueSize)103 ssize_t NNCache::getBlob(const void* key, ssize_t keySize, void* value, ssize_t valueSize) {
104 std::lock_guard<std::mutex> lock(mMutex);
105
106 if (keySize < 0 || valueSize < 0) {
107 ALOGW("nnCache::getBlob: negative sizes are not allowed");
108 return 0;
109 }
110
111 if (mInitialized) {
112 BlobCache* bc = getBlobCacheLocked();
113 return bc->get(key, keySize, value, valueSize);
114 }
115 return 0;
116 }
117
getBlob(const void * key,ssize_t keySize,void ** value,std::function<void * (size_t)> alloc)118 ssize_t NNCache::getBlob(const void* key, ssize_t keySize, void** value,
119 std::function<void*(size_t)> alloc) {
120 std::lock_guard<std::mutex> lock(mMutex);
121
122 if (keySize < 0) {
123 ALOGW("nnCache::getBlob: negative sizes are not allowed");
124 return 0;
125 }
126
127 if (mInitialized) {
128 BlobCache* bc = getBlobCacheLocked();
129 return bc->get(key, keySize, value, alloc);
130 }
131 return 0;
132 }
133
setCacheFilename(const char * filename)134 void NNCache::setCacheFilename(const char* filename) {
135 std::lock_guard<std::mutex> lock(mMutex);
136 mFilename = filename;
137 }
138
getBlobCacheLocked()139 BlobCache* NNCache::getBlobCacheLocked() {
140 if (mBlobCache == nullptr) {
141 mBlobCache.reset(new BlobCache(mMaxKeySize, mMaxValueSize, mMaxTotalSize, mPolicy));
142 loadBlobCacheLocked();
143 }
144 return mBlobCache.get();
145 }
146
// Computes a CRC over the first len bytes of buf using the reflected
// CRC-32C (Castagnoli) polynomial 0x82F63B78, one bit at a time.
//
// NOTE: the register starts at 0 and the result is not inverted. The values
// therefore differ from the standard CRC-32C check values, which use an
// all-ones initial value and final XOR. The cache file is both written and
// verified with this same function, so it is self-consistent.
static uint32_t crc32c(const uint8_t* buf, size_t len) {
    constexpr uint32_t kReflectedPoly = 0x82F63B78;
    uint32_t crc = 0;
    for (const uint8_t *p = buf, *end = buf + len; p != end; ++p) {
        crc ^= *p;
        for (int bit = 0; bit < 8; ++bit) {
            // All-ones when the low bit is set, zero otherwise.
            const uint32_t mask = 0u - (crc & 1u);
            crc = (crc >> 1) ^ (kReflectedPoly & mask);
        }
    }
    return crc;
}
162
saveBlobCacheLocked()163 void NNCache::saveBlobCacheLocked() {
164 if (mFilename.length() > 0 && mBlobCache != NULL) {
165 size_t cacheSize = mBlobCache->getFlattenedSize();
166 size_t headerSize = cacheFileHeaderSize;
167 const char* fname = mFilename.c_str();
168
169 // Try to create the file with no permissions so we can write it
170 // without anyone trying to read it.
171 int fd = open(fname, O_CREAT | O_EXCL | O_RDWR, 0);
172 if (fd == -1) {
173 if (errno == EEXIST) {
174 // The file exists, delete it and try again.
175 if (unlink(fname) == -1) {
176 // No point in retrying if the unlink failed.
177 ALOGE("error unlinking cache file %s: %s (%d)", fname, strerror(errno), errno);
178 return;
179 }
180 // Retry now that we've unlinked the file.
181 fd = open(fname, O_CREAT | O_EXCL | O_RDWR, 0);
182 }
183 if (fd == -1) {
184 ALOGE("error creating cache file %s: %s (%d)", fname, strerror(errno), errno);
185 return;
186 }
187 }
188
189 size_t fileSize = headerSize + cacheSize;
190
191 uint8_t* buf = new uint8_t[fileSize];
192 if (!buf) {
193 ALOGE("error allocating buffer for cache contents: %s (%d)", strerror(errno), errno);
194 close(fd);
195 unlink(fname);
196 return;
197 }
198
199 int err = mBlobCache->flatten(buf + headerSize, cacheSize);
200 if (err < 0) {
201 ALOGE("error writing cache contents: %s (%d)", strerror(-err), -err);
202 delete[] buf;
203 close(fd);
204 unlink(fname);
205 return;
206 }
207
208 // Write the file magic and CRC
209 memcpy(buf, cacheFileMagic, 4);
210 uint32_t* crc = reinterpret_cast<uint32_t*>(buf + 4);
211 *crc = crc32c(buf + headerSize, cacheSize);
212
213 if (write(fd, buf, fileSize) == -1) {
214 ALOGE("error writing cache file: %s (%d)", strerror(errno), errno);
215 delete[] buf;
216 close(fd);
217 unlink(fname);
218 return;
219 }
220
221 delete[] buf;
222 fchmod(fd, S_IRUSR);
223 close(fd);
224 }
225 }
226
loadBlobCacheLocked()227 void NNCache::loadBlobCacheLocked() {
228 if (mFilename.length() > 0) {
229 size_t headerSize = cacheFileHeaderSize;
230
231 int fd = open(mFilename.c_str(), O_RDONLY, 0);
232 if (fd == -1) {
233 if (errno != ENOENT) {
234 ALOGE("error opening cache file %s: %s (%d)", mFilename.c_str(), strerror(errno),
235 errno);
236 }
237 return;
238 }
239
240 struct stat statBuf;
241 if (fstat(fd, &statBuf) == -1) {
242 ALOGE("error stat'ing cache file: %s (%d)", strerror(errno), errno);
243 close(fd);
244 return;
245 }
246
247 // Validity check the size before trying to mmap it.
248 size_t fileSize = statBuf.st_size;
249 if (fileSize > mMaxTotalSize * 2) {
250 ALOGE("cache file is too large: %#" PRIx64, static_cast<off64_t>(statBuf.st_size));
251 close(fd);
252 return;
253 }
254
255 uint8_t* buf =
256 reinterpret_cast<uint8_t*>(mmap(NULL, fileSize, PROT_READ, MAP_PRIVATE, fd, 0));
257 if (buf == MAP_FAILED) {
258 ALOGE("error mmaping cache file: %s (%d)", strerror(errno), errno);
259 close(fd);
260 return;
261 }
262
263 // Check the file magic and CRC
264 size_t cacheSize = fileSize - headerSize;
265 if (memcmp(buf, cacheFileMagic, 4) != 0) {
266 ALOGE("cache file has bad mojo");
267 close(fd);
268 return;
269 }
270 uint32_t* crc = reinterpret_cast<uint32_t*>(buf + 4);
271 if (crc32c(buf + headerSize, cacheSize) != *crc) {
272 ALOGE("cache file failed CRC check");
273 close(fd);
274 return;
275 }
276
277 int err = mBlobCache->unflatten(buf + headerSize, cacheSize);
278 if (err < 0) {
279 ALOGE("error reading cache contents: %s (%d)", strerror(-err), -err);
280 munmap(buf, fileSize);
281 close(fd);
282 return;
283 }
284
285 munmap(buf, fileSize);
286 close(fd);
287 }
288 }
289
290 // ----------------------------------------------------------------------------
291 }; // namespace android
292 // ----------------------------------------------------------------------------
293