1 /*
2  ** Copyright 2011, The Android Open Source Project
3  **
4  ** Licensed under the Apache License, Version 2.0 (the "License");
5  ** you may not use this file except in compliance with the License.
6  ** You may obtain a copy of the License at
7  **
8  **     http://www.apache.org/licenses/LICENSE-2.0
9  **
10  ** Unless required by applicable law or agreed to in writing, software
11  ** distributed under the License is distributed on an "AS IS" BASIS,
12  ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  ** See the License for the specific language governing permissions and
14  ** limitations under the License.
15  */
16 
17 #include "nnCache.h"
18 
#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <log/log.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include <functional>
#include <memory>
#include <new>
#include <thread>
28 
// Cache file header: the 4-byte magic string below, then a 4-byte CRC32C of
// the flattened cache payload that follows the header.
static constexpr const char* cacheFileMagic = "nn$$";
static constexpr size_t cacheFileHeaderSize = 8;

// Number of seconds a deferred-save thread sleeps before persisting newly
// inserted cache entries to disk.
static constexpr unsigned int deferredSaveDelay = 4;
35 
36 // ----------------------------------------------------------------------------
37 namespace android {
38 // ----------------------------------------------------------------------------
39 
40 //
41 // NNCache definition
42 //
NNCache()43 NNCache::NNCache()
44     : mInitialized(false),
45       mMaxKeySize(0),
46       mMaxValueSize(0),
47       mMaxTotalSize(0),
48       mPolicy(defaultPolicy()),
49       mSavePending(false) {}
50 
~NNCache()51 NNCache::~NNCache() {}
52 
53 NNCache NNCache::sCache;
54 
get()55 NNCache* NNCache::get() {
56     return &sCache;
57 }
58 
initialize(size_t maxKeySize,size_t maxValueSize,size_t maxTotalSize,Policy policy)59 void NNCache::initialize(size_t maxKeySize, size_t maxValueSize, size_t maxTotalSize,
60                          Policy policy) {
61     std::lock_guard<std::mutex> lock(mMutex);
62     mInitialized = true;
63     mMaxKeySize = maxKeySize;
64     mMaxValueSize = maxValueSize;
65     mMaxTotalSize = maxTotalSize;
66     mPolicy = policy;
67 }
68 
terminate()69 void NNCache::terminate() {
70     std::lock_guard<std::mutex> lock(mMutex);
71     saveBlobCacheLocked();
72     mBlobCache = NULL;
73     mInitialized = false;
74 }
75 
setBlob(const void * key,ssize_t keySize,const void * value,ssize_t valueSize)76 void NNCache::setBlob(const void* key, ssize_t keySize, const void* value, ssize_t valueSize) {
77     std::lock_guard<std::mutex> lock(mMutex);
78 
79     if (keySize < 0 || valueSize < 0) {
80         ALOGW("nnCache::setBlob: negative sizes are not allowed");
81         return;
82     }
83 
84     if (mInitialized) {
85         BlobCache* bc = getBlobCacheLocked();
86         bc->set(key, keySize, value, valueSize);
87 
88         if (!mSavePending) {
89             mSavePending = true;
90             std::thread deferredSaveThread([this]() {
91                 sleep(deferredSaveDelay);
92                 std::lock_guard<std::mutex> lock(mMutex);
93                 if (mInitialized) {
94                     saveBlobCacheLocked();
95                 }
96                 mSavePending = false;
97             });
98             deferredSaveThread.detach();
99         }
100     }
101 }
102 
getBlob(const void * key,ssize_t keySize,void * value,ssize_t valueSize)103 ssize_t NNCache::getBlob(const void* key, ssize_t keySize, void* value, ssize_t valueSize) {
104     std::lock_guard<std::mutex> lock(mMutex);
105 
106     if (keySize < 0 || valueSize < 0) {
107         ALOGW("nnCache::getBlob: negative sizes are not allowed");
108         return 0;
109     }
110 
111     if (mInitialized) {
112         BlobCache* bc = getBlobCacheLocked();
113         return bc->get(key, keySize, value, valueSize);
114     }
115     return 0;
116 }
117 
getBlob(const void * key,ssize_t keySize,void ** value,std::function<void * (size_t)> alloc)118 ssize_t NNCache::getBlob(const void* key, ssize_t keySize, void** value,
119                          std::function<void*(size_t)> alloc) {
120     std::lock_guard<std::mutex> lock(mMutex);
121 
122     if (keySize < 0) {
123         ALOGW("nnCache::getBlob: negative sizes are not allowed");
124         return 0;
125     }
126 
127     if (mInitialized) {
128         BlobCache* bc = getBlobCacheLocked();
129         return bc->get(key, keySize, value, alloc);
130     }
131     return 0;
132 }
133 
setCacheFilename(const char * filename)134 void NNCache::setCacheFilename(const char* filename) {
135     std::lock_guard<std::mutex> lock(mMutex);
136     mFilename = filename;
137 }
138 
getBlobCacheLocked()139 BlobCache* NNCache::getBlobCacheLocked() {
140     if (mBlobCache == nullptr) {
141         mBlobCache.reset(new BlobCache(mMaxKeySize, mMaxValueSize, mMaxTotalSize, mPolicy));
142         loadBlobCacheLocked();
143     }
144     return mBlobCache.get();
145 }
146 
crc32c(const uint8_t * buf,size_t len)147 static uint32_t crc32c(const uint8_t* buf, size_t len) {
148     const uint32_t polyBits = 0x82F63B78;
149     uint32_t r = 0;
150     for (size_t i = 0; i < len; i++) {
151         r ^= buf[i];
152         for (int j = 0; j < 8; j++) {
153             if (r & 1) {
154                 r = (r >> 1) ^ polyBits;
155             } else {
156                 r >>= 1;
157             }
158         }
159     }
160     return r;
161 }
162 
saveBlobCacheLocked()163 void NNCache::saveBlobCacheLocked() {
164     if (mFilename.length() > 0 && mBlobCache != NULL) {
165         size_t cacheSize = mBlobCache->getFlattenedSize();
166         size_t headerSize = cacheFileHeaderSize;
167         const char* fname = mFilename.c_str();
168 
169         // Try to create the file with no permissions so we can write it
170         // without anyone trying to read it.
171         int fd = open(fname, O_CREAT | O_EXCL | O_RDWR, 0);
172         if (fd == -1) {
173             if (errno == EEXIST) {
174                 // The file exists, delete it and try again.
175                 if (unlink(fname) == -1) {
176                     // No point in retrying if the unlink failed.
177                     ALOGE("error unlinking cache file %s: %s (%d)", fname, strerror(errno), errno);
178                     return;
179                 }
180                 // Retry now that we've unlinked the file.
181                 fd = open(fname, O_CREAT | O_EXCL | O_RDWR, 0);
182             }
183             if (fd == -1) {
184                 ALOGE("error creating cache file %s: %s (%d)", fname, strerror(errno), errno);
185                 return;
186             }
187         }
188 
189         size_t fileSize = headerSize + cacheSize;
190 
191         uint8_t* buf = new uint8_t[fileSize];
192         if (!buf) {
193             ALOGE("error allocating buffer for cache contents: %s (%d)", strerror(errno), errno);
194             close(fd);
195             unlink(fname);
196             return;
197         }
198 
199         int err = mBlobCache->flatten(buf + headerSize, cacheSize);
200         if (err < 0) {
201             ALOGE("error writing cache contents: %s (%d)", strerror(-err), -err);
202             delete[] buf;
203             close(fd);
204             unlink(fname);
205             return;
206         }
207 
208         // Write the file magic and CRC
209         memcpy(buf, cacheFileMagic, 4);
210         uint32_t* crc = reinterpret_cast<uint32_t*>(buf + 4);
211         *crc = crc32c(buf + headerSize, cacheSize);
212 
213         if (write(fd, buf, fileSize) == -1) {
214             ALOGE("error writing cache file: %s (%d)", strerror(errno), errno);
215             delete[] buf;
216             close(fd);
217             unlink(fname);
218             return;
219         }
220 
221         delete[] buf;
222         fchmod(fd, S_IRUSR);
223         close(fd);
224     }
225 }
226 
loadBlobCacheLocked()227 void NNCache::loadBlobCacheLocked() {
228     if (mFilename.length() > 0) {
229         size_t headerSize = cacheFileHeaderSize;
230 
231         int fd = open(mFilename.c_str(), O_RDONLY, 0);
232         if (fd == -1) {
233             if (errno != ENOENT) {
234                 ALOGE("error opening cache file %s: %s (%d)", mFilename.c_str(), strerror(errno),
235                       errno);
236             }
237             return;
238         }
239 
240         struct stat statBuf;
241         if (fstat(fd, &statBuf) == -1) {
242             ALOGE("error stat'ing cache file: %s (%d)", strerror(errno), errno);
243             close(fd);
244             return;
245         }
246 
247         // Validity check the size before trying to mmap it.
248         size_t fileSize = statBuf.st_size;
249         if (fileSize > mMaxTotalSize * 2) {
250             ALOGE("cache file is too large: %#" PRIx64, static_cast<off64_t>(statBuf.st_size));
251             close(fd);
252             return;
253         }
254 
255         uint8_t* buf =
256                 reinterpret_cast<uint8_t*>(mmap(NULL, fileSize, PROT_READ, MAP_PRIVATE, fd, 0));
257         if (buf == MAP_FAILED) {
258             ALOGE("error mmaping cache file: %s (%d)", strerror(errno), errno);
259             close(fd);
260             return;
261         }
262 
263         // Check the file magic and CRC
264         size_t cacheSize = fileSize - headerSize;
265         if (memcmp(buf, cacheFileMagic, 4) != 0) {
266             ALOGE("cache file has bad mojo");
267             close(fd);
268             return;
269         }
270         uint32_t* crc = reinterpret_cast<uint32_t*>(buf + 4);
271         if (crc32c(buf + headerSize, cacheSize) != *crc) {
272             ALOGE("cache file failed CRC check");
273             close(fd);
274             return;
275         }
276 
277         int err = mBlobCache->unflatten(buf + headerSize, cacheSize);
278         if (err < 0) {
279             ALOGE("error reading cache contents: %s (%d)", strerror(-err), -err);
280             munmap(buf, fileSize);
281             close(fd);
282             return;
283         }
284 
285         munmap(buf, fileSize);
286         close(fd);
287     }
288 }
289 
290 // ----------------------------------------------------------------------------
291 };  // namespace android
292 // ----------------------------------------------------------------------------
293