1 /*
2  ** Copyright 2011, The Android Open Source Project
3  **
4  ** Licensed under the Apache License, Version 2.0 (the "License");
5  ** you may not use this file except in compliance with the License.
6  ** You may obtain a copy of the License at
7  **
8  **     http://www.apache.org/licenses/LICENSE-2.0
9  **
10  ** Unless required by applicable law or agreed to in writing, software
11  ** distributed under the License is distributed on an "AS IS" BASIS,
12  ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  ** See the License for the specific language governing permissions and
14  ** limitations under the License.
15  */
16 
#include "nnCache.h"

#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include <thread>
#include <vector>

#include <log/log.h>
27 
// Cache file header: the 4-byte magic string below followed by a 4-byte CRC
// of the payload, so the payload starts cacheFileHeaderSize bytes into the file.
static constexpr char cacheFileMagic[] = "nn$$";
static constexpr size_t cacheFileHeaderSize = 8;
// The magic must fill exactly the first 4 header bytes (readers/writers copy 4).
static_assert(sizeof(cacheFileMagic) - 1 == 4, "cache file magic must be 4 bytes");

// The time in seconds to wait before saving newly inserted cache entries.
static constexpr unsigned int deferredSaveDelay = 4;
34 
35 // ----------------------------------------------------------------------------
36 namespace android {
37 // ----------------------------------------------------------------------------
38 
39 //
40 // NNCache definition
41 //
NNCache()42 NNCache::NNCache() :
43     mInitialized(false),
44     mMaxKeySize(0), mMaxValueSize(0), mMaxTotalSize(0),
45     mPolicy(defaultPolicy()),
46     mSavePending(false) {
47 }
48 
~NNCache()49 NNCache::~NNCache() {
50 }
51 
52 NNCache NNCache::sCache;
53 
get()54 NNCache* NNCache::get() {
55     return &sCache;
56 }
57 
initialize(size_t maxKeySize,size_t maxValueSize,size_t maxTotalSize,Policy policy)58 void NNCache::initialize(size_t maxKeySize, size_t maxValueSize, size_t maxTotalSize,
59                          Policy policy) {
60     std::lock_guard<std::mutex> lock(mMutex);
61     mInitialized = true;
62     mMaxKeySize = maxKeySize;
63     mMaxValueSize = maxValueSize;
64     mMaxTotalSize = maxTotalSize;
65     mPolicy = policy;
66 }
67 
terminate()68 void NNCache::terminate() {
69     std::lock_guard<std::mutex> lock(mMutex);
70     saveBlobCacheLocked();
71     mBlobCache = NULL;
72     mInitialized = false;
73 }
74 
setBlob(const void * key,ssize_t keySize,const void * value,ssize_t valueSize)75 void NNCache::setBlob(const void* key, ssize_t keySize,
76         const void* value, ssize_t valueSize) {
77     std::lock_guard<std::mutex> lock(mMutex);
78 
79     if (keySize < 0 || valueSize < 0) {
80         ALOGW("nnCache::setBlob: negative sizes are not allowed");
81         return;
82     }
83 
84     if (mInitialized) {
85         BlobCache* bc = getBlobCacheLocked();
86         bc->set(key, keySize, value, valueSize);
87 
88         if (!mSavePending) {
89             mSavePending = true;
90             std::thread deferredSaveThread([this]() {
91                 sleep(deferredSaveDelay);
92                 std::lock_guard<std::mutex> lock(mMutex);
93                 if (mInitialized) {
94                     saveBlobCacheLocked();
95                 }
96                 mSavePending = false;
97             });
98             deferredSaveThread.detach();
99         }
100     }
101 }
102 
getBlob(const void * key,ssize_t keySize,void * value,ssize_t valueSize)103 ssize_t NNCache::getBlob(const void* key, ssize_t keySize,
104         void* value, ssize_t valueSize) {
105     std::lock_guard<std::mutex> lock(mMutex);
106 
107     if (keySize < 0 || valueSize < 0) {
108         ALOGW("nnCache::getBlob: negative sizes are not allowed");
109         return 0;
110     }
111 
112     if (mInitialized) {
113         BlobCache* bc = getBlobCacheLocked();
114         return bc->get(key, keySize, value, valueSize);
115     }
116     return 0;
117 }
118 
getBlob(const void * key,ssize_t keySize,void ** value,std::function<void * (size_t)> alloc)119 ssize_t NNCache::getBlob(const void* key, ssize_t keySize,
120         void** value, std::function<void*(size_t)> alloc) {
121     std::lock_guard<std::mutex> lock(mMutex);
122 
123     if (keySize < 0) {
124         ALOGW("nnCache::getBlob: negative sizes are not allowed");
125         return 0;
126     }
127 
128     if (mInitialized) {
129         BlobCache* bc = getBlobCacheLocked();
130         return bc->get(key, keySize, value, alloc);
131     }
132     return 0;
133 }
134 
setCacheFilename(const char * filename)135 void NNCache::setCacheFilename(const char* filename) {
136     std::lock_guard<std::mutex> lock(mMutex);
137     mFilename = filename;
138 }
139 
getBlobCacheLocked()140 BlobCache* NNCache::getBlobCacheLocked() {
141     if (mBlobCache == nullptr) {
142         mBlobCache.reset(new BlobCache(mMaxKeySize, mMaxValueSize, mMaxTotalSize, mPolicy));
143         loadBlobCacheLocked();
144     }
145     return mBlobCache.get();
146 }
147 
// Bitwise CRC over |len| bytes at |buf|, using the reflected CRC-32C
// (Castagnoli) polynomial 0x82F63B78.
// NOTE: the register starts at 0 and no final XOR is applied, so results
// differ from the standard CRC-32C check value. This value is stored in and
// verified against cache files, so changing the algorithm would make
// previously written files fail the CRC check on load.
static uint32_t crc32c(const uint8_t* buf, size_t len) {
    constexpr uint32_t kReflectedPoly = 0x82F63B78;
    uint32_t crc = 0;
    for (size_t pos = 0; pos != len; ++pos) {
        crc ^= buf[pos];
        for (int bit = 0; bit != 8; ++bit) {
            // All-ones when the low bit is set, zero otherwise.
            const uint32_t mask = 0u - (crc & 1u);
            crc = (crc >> 1) ^ (kReflectedPoly & mask);
        }
    }
    return crc;
}
163 
saveBlobCacheLocked()164 void NNCache::saveBlobCacheLocked() {
165     if (mFilename.length() > 0 && mBlobCache != NULL) {
166         size_t cacheSize = mBlobCache->getFlattenedSize();
167         size_t headerSize = cacheFileHeaderSize;
168         const char* fname = mFilename.c_str();
169 
170         // Try to create the file with no permissions so we can write it
171         // without anyone trying to read it.
172         int fd = open(fname, O_CREAT | O_EXCL | O_RDWR, 0);
173         if (fd == -1) {
174             if (errno == EEXIST) {
175                 // The file exists, delete it and try again.
176                 if (unlink(fname) == -1) {
177                     // No point in retrying if the unlink failed.
178                     ALOGE("error unlinking cache file %s: %s (%d)", fname,
179                             strerror(errno), errno);
180                     return;
181                 }
182                 // Retry now that we've unlinked the file.
183                 fd = open(fname, O_CREAT | O_EXCL | O_RDWR, 0);
184             }
185             if (fd == -1) {
186                 ALOGE("error creating cache file %s: %s (%d)", fname,
187                         strerror(errno), errno);
188                 return;
189             }
190         }
191 
192         size_t fileSize = headerSize + cacheSize;
193 
194         uint8_t* buf = new uint8_t [fileSize];
195         if (!buf) {
196             ALOGE("error allocating buffer for cache contents: %s (%d)",
197                     strerror(errno), errno);
198             close(fd);
199             unlink(fname);
200             return;
201         }
202 
203         int err = mBlobCache->flatten(buf + headerSize, cacheSize);
204         if (err < 0) {
205             ALOGE("error writing cache contents: %s (%d)", strerror(-err),
206                     -err);
207             delete [] buf;
208             close(fd);
209             unlink(fname);
210             return;
211         }
212 
213         // Write the file magic and CRC
214         memcpy(buf, cacheFileMagic, 4);
215         uint32_t* crc = reinterpret_cast<uint32_t*>(buf + 4);
216         *crc = crc32c(buf + headerSize, cacheSize);
217 
218         if (write(fd, buf, fileSize) == -1) {
219             ALOGE("error writing cache file: %s (%d)", strerror(errno),
220                     errno);
221             delete [] buf;
222             close(fd);
223             unlink(fname);
224             return;
225         }
226 
227         delete [] buf;
228         fchmod(fd, S_IRUSR);
229         close(fd);
230     }
231 }
232 
loadBlobCacheLocked()233 void NNCache::loadBlobCacheLocked() {
234     if (mFilename.length() > 0) {
235         size_t headerSize = cacheFileHeaderSize;
236 
237         int fd = open(mFilename.c_str(), O_RDONLY, 0);
238         if (fd == -1) {
239             if (errno != ENOENT) {
240                 ALOGE("error opening cache file %s: %s (%d)", mFilename.c_str(),
241                         strerror(errno), errno);
242             }
243             return;
244         }
245 
246         struct stat statBuf;
247         if (fstat(fd, &statBuf) == -1) {
248             ALOGE("error stat'ing cache file: %s (%d)", strerror(errno), errno);
249             close(fd);
250             return;
251         }
252 
253         // Sanity check the size before trying to mmap it.
254         size_t fileSize = statBuf.st_size;
255         if (fileSize > mMaxTotalSize * 2) {
256             ALOGE("cache file is too large: %#" PRIx64,
257                   static_cast<off64_t>(statBuf.st_size));
258             close(fd);
259             return;
260         }
261 
262         uint8_t* buf = reinterpret_cast<uint8_t*>(mmap(NULL, fileSize,
263                 PROT_READ, MAP_PRIVATE, fd, 0));
264         if (buf == MAP_FAILED) {
265             ALOGE("error mmaping cache file: %s (%d)", strerror(errno),
266                     errno);
267             close(fd);
268             return;
269         }
270 
271         // Check the file magic and CRC
272         size_t cacheSize = fileSize - headerSize;
273         if (memcmp(buf, cacheFileMagic, 4) != 0) {
274             ALOGE("cache file has bad mojo");
275             close(fd);
276             return;
277         }
278         uint32_t* crc = reinterpret_cast<uint32_t*>(buf + 4);
279         if (crc32c(buf + headerSize, cacheSize) != *crc) {
280             ALOGE("cache file failed CRC check");
281             close(fd);
282             return;
283         }
284 
285         int err = mBlobCache->unflatten(buf + headerSize, cacheSize);
286         if (err < 0) {
287             ALOGE("error reading cache contents: %s (%d)", strerror(-err),
288                     -err);
289             munmap(buf, fileSize);
290             close(fd);
291             return;
292         }
293 
294         munmap(buf, fileSize);
295         close(fd);
296     }
297 }
298 
299 // ----------------------------------------------------------------------------
300 }; // namespace android
301 // ----------------------------------------------------------------------------
302