1 /*
2  ** Copyright 2011, The Android Open Source Project
3  **
4  ** Licensed under the Apache License, Version 2.0 (the "License");
5  ** you may not use this file except in compliance with the License.
6  ** You may obtain a copy of the License at
7  **
8  **     http://www.apache.org/licenses/LICENSE-2.0
9  **
10  ** Unless required by applicable law or agreed to in writing, software
11  ** distributed under the License is distributed on an "AS IS" BASIS,
12  ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  ** See the License for the specific language governing permissions and
14  ** limitations under the License.
15  */
16 
#include "nnCache.h"

#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include <memory>
#include <new>
#include <thread>

#include <log/log.h>
27 
// On-disk cache file layout: a 4-byte magic string, then a 4-byte CRC32C of the
// payload, then the flattened BlobCache (the payload) itself.
static constexpr const char* cacheFileMagic = "nn$$";
static constexpr size_t cacheFileHeaderSize = 4 + sizeof(uint32_t);

// Seconds to wait after the first new cache entry is inserted before the
// whole cache is written back to disk.
static constexpr unsigned int deferredSaveDelay = 4;
34 
35 // ----------------------------------------------------------------------------
36 namespace android {
37 // ----------------------------------------------------------------------------
38 
39 //
40 // NNCache definition
41 //
NNCache()42 NNCache::NNCache()
43     : mInitialized(false),
44       mMaxKeySize(0),
45       mMaxValueSize(0),
46       mMaxTotalSize(0),
47       mPolicy(defaultPolicy()),
48       mSavePending(false) {}
49 
~NNCache()50 NNCache::~NNCache() {}
51 
// The single process-wide instance handed out by NNCache::get().
NNCache NNCache::sCache;
53 
get()54 NNCache* NNCache::get() {
55     return &sCache;
56 }
57 
initialize(size_t maxKeySize,size_t maxValueSize,size_t maxTotalSize,Policy policy)58 void NNCache::initialize(size_t maxKeySize, size_t maxValueSize, size_t maxTotalSize,
59                          Policy policy) {
60     std::lock_guard<std::mutex> lock(mMutex);
61     mInitialized = true;
62     mMaxKeySize = maxKeySize;
63     mMaxValueSize = maxValueSize;
64     mMaxTotalSize = maxTotalSize;
65     mPolicy = policy;
66 }
67 
terminate()68 void NNCache::terminate() {
69     std::lock_guard<std::mutex> lock(mMutex);
70     saveBlobCacheLocked();
71     mBlobCache = NULL;
72     mInitialized = false;
73 }
74 
setBlob(const void * key,ssize_t keySize,const void * value,ssize_t valueSize)75 void NNCache::setBlob(const void* key, ssize_t keySize, const void* value, ssize_t valueSize) {
76     std::lock_guard<std::mutex> lock(mMutex);
77 
78     if (keySize < 0 || valueSize < 0) {
79         ALOGW("nnCache::setBlob: negative sizes are not allowed");
80         return;
81     }
82 
83     if (mInitialized) {
84         BlobCache* bc = getBlobCacheLocked();
85         bc->set(key, keySize, value, valueSize);
86 
87         if (!mSavePending) {
88             mSavePending = true;
89             std::thread deferredSaveThread([this]() {
90                 sleep(deferredSaveDelay);
91                 std::lock_guard<std::mutex> lock(mMutex);
92                 if (mInitialized) {
93                     saveBlobCacheLocked();
94                 }
95                 mSavePending = false;
96             });
97             deferredSaveThread.detach();
98         }
99     }
100 }
101 
getBlob(const void * key,ssize_t keySize,void * value,ssize_t valueSize)102 ssize_t NNCache::getBlob(const void* key, ssize_t keySize, void* value, ssize_t valueSize) {
103     std::lock_guard<std::mutex> lock(mMutex);
104 
105     if (keySize < 0 || valueSize < 0) {
106         ALOGW("nnCache::getBlob: negative sizes are not allowed");
107         return 0;
108     }
109 
110     if (mInitialized) {
111         BlobCache* bc = getBlobCacheLocked();
112         return bc->get(key, keySize, value, valueSize);
113     }
114     return 0;
115 }
116 
getBlob(const void * key,ssize_t keySize,void ** value,std::function<void * (size_t)> alloc)117 ssize_t NNCache::getBlob(const void* key, ssize_t keySize, void** value,
118                          std::function<void*(size_t)> alloc) {
119     std::lock_guard<std::mutex> lock(mMutex);
120 
121     if (keySize < 0) {
122         ALOGW("nnCache::getBlob: negative sizes are not allowed");
123         return 0;
124     }
125 
126     if (mInitialized) {
127         BlobCache* bc = getBlobCacheLocked();
128         return bc->get(key, keySize, value, alloc);
129     }
130     return 0;
131 }
132 
setCacheFilename(const char * filename)133 void NNCache::setCacheFilename(const char* filename) {
134     std::lock_guard<std::mutex> lock(mMutex);
135     mFilename = filename;
136 }
137 
getBlobCacheLocked()138 BlobCache* NNCache::getBlobCacheLocked() {
139     if (mBlobCache == nullptr) {
140         mBlobCache.reset(new BlobCache(mMaxKeySize, mMaxValueSize, mMaxTotalSize, mPolicy));
141         loadBlobCacheLocked();
142     }
143     return mBlobCache.get();
144 }
145 
// Bitwise CRC-32C (Castagnoli, reflected polynomial 0x82F63B78) over
// buf[0, len).
// Unlike the usual CRC-32C convention, the register starts at 0 and the result
// is not inverted. The cache file's header CRC is produced and checked with
// exactly this variant, so changing it would invalidate existing files.
static uint32_t crc32c(const uint8_t* buf, size_t len) {
    constexpr uint32_t kReflectedPoly = 0x82F63B78;
    uint32_t crc = 0;
    for (const uint8_t* p = buf; p != buf + len; ++p) {
        crc ^= *p;
        for (int bit = 0; bit < 8; ++bit) {
            const bool lowBitSet = (crc & 1u) != 0;
            crc >>= 1;
            if (lowBitSet) {
                crc ^= kReflectedPoly;
            }
        }
    }
    return crc;
}
161 
saveBlobCacheLocked()162 void NNCache::saveBlobCacheLocked() {
163     if (mFilename.length() > 0 && mBlobCache != NULL) {
164         size_t cacheSize = mBlobCache->getFlattenedSize();
165         size_t headerSize = cacheFileHeaderSize;
166         const char* fname = mFilename.c_str();
167 
168         // Try to create the file with no permissions so we can write it
169         // without anyone trying to read it.
170         int fd = open(fname, O_CREAT | O_EXCL | O_RDWR, 0);
171         if (fd == -1) {
172             if (errno == EEXIST) {
173                 // The file exists, delete it and try again.
174                 if (unlink(fname) == -1) {
175                     // No point in retrying if the unlink failed.
176                     ALOGE("error unlinking cache file %s: %s (%d)", fname, strerror(errno), errno);
177                     return;
178                 }
179                 // Retry now that we've unlinked the file.
180                 fd = open(fname, O_CREAT | O_EXCL | O_RDWR, 0);
181             }
182             if (fd == -1) {
183                 ALOGE("error creating cache file %s: %s (%d)", fname, strerror(errno), errno);
184                 return;
185             }
186         }
187 
188         size_t fileSize = headerSize + cacheSize;
189 
190         uint8_t* buf = new uint8_t[fileSize];
191         if (!buf) {
192             ALOGE("error allocating buffer for cache contents: %s (%d)", strerror(errno), errno);
193             close(fd);
194             unlink(fname);
195             return;
196         }
197 
198         int err = mBlobCache->flatten(buf + headerSize, cacheSize);
199         if (err < 0) {
200             ALOGE("error writing cache contents: %s (%d)", strerror(-err), -err);
201             delete[] buf;
202             close(fd);
203             unlink(fname);
204             return;
205         }
206 
207         // Write the file magic and CRC
208         memcpy(buf, cacheFileMagic, 4);
209         uint32_t* crc = reinterpret_cast<uint32_t*>(buf + 4);
210         *crc = crc32c(buf + headerSize, cacheSize);
211 
212         if (write(fd, buf, fileSize) == -1) {
213             ALOGE("error writing cache file: %s (%d)", strerror(errno), errno);
214             delete[] buf;
215             close(fd);
216             unlink(fname);
217             return;
218         }
219 
220         delete[] buf;
221         fchmod(fd, S_IRUSR);
222         close(fd);
223     }
224 }
225 
loadBlobCacheLocked()226 void NNCache::loadBlobCacheLocked() {
227     if (mFilename.length() > 0) {
228         size_t headerSize = cacheFileHeaderSize;
229 
230         int fd = open(mFilename.c_str(), O_RDONLY, 0);
231         if (fd == -1) {
232             if (errno != ENOENT) {
233                 ALOGE("error opening cache file %s: %s (%d)", mFilename.c_str(), strerror(errno),
234                       errno);
235             }
236             return;
237         }
238 
239         struct stat statBuf;
240         if (fstat(fd, &statBuf) == -1) {
241             ALOGE("error stat'ing cache file: %s (%d)", strerror(errno), errno);
242             close(fd);
243             return;
244         }
245 
246         // Sanity check the size before trying to mmap it.
247         size_t fileSize = statBuf.st_size;
248         if (fileSize > mMaxTotalSize * 2) {
249             ALOGE("cache file is too large: %#" PRIx64, static_cast<off64_t>(statBuf.st_size));
250             close(fd);
251             return;
252         }
253 
254         uint8_t* buf =
255                 reinterpret_cast<uint8_t*>(mmap(NULL, fileSize, PROT_READ, MAP_PRIVATE, fd, 0));
256         if (buf == MAP_FAILED) {
257             ALOGE("error mmaping cache file: %s (%d)", strerror(errno), errno);
258             close(fd);
259             return;
260         }
261 
262         // Check the file magic and CRC
263         size_t cacheSize = fileSize - headerSize;
264         if (memcmp(buf, cacheFileMagic, 4) != 0) {
265             ALOGE("cache file has bad mojo");
266             close(fd);
267             return;
268         }
269         uint32_t* crc = reinterpret_cast<uint32_t*>(buf + 4);
270         if (crc32c(buf + headerSize, cacheSize) != *crc) {
271             ALOGE("cache file failed CRC check");
272             close(fd);
273             return;
274         }
275 
276         int err = mBlobCache->unflatten(buf + headerSize, cacheSize);
277         if (err < 0) {
278             ALOGE("error reading cache contents: %s (%d)", strerror(-err), -err);
279             munmap(buf, fileSize);
280             close(fd);
281             return;
282         }
283 
284         munmap(buf, fileSize);
285         close(fd);
286     }
287 }
288 
289 // ----------------------------------------------------------------------------
290 };  // namespace android
291 // ----------------------------------------------------------------------------
292