/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "neuralnetworks_hidl_hal_test"

#include <android-base/logging.h>
#include <fcntl.h>
#include <ftw.h>
#include <gtest/gtest.h>
#include <hidlmemory/mapping.h>
#include <unistd.h>

#include <cstdio>
#include <cstdlib>
#include <random>
#include <thread>

#include "1.2/Callbacks.h"
#include "GeneratedTestHarness.h"
#include "MemoryUtils.h"
#include "TestHarness.h"
#include "VtsHalNeuralnetworks.h"

// Forward declaration of the mobilenet generated test models in
// frameworks/ml/nn/runtime/test/generated/.
namespace generated_tests::mobilenet_224_gender_basic_fixed {
const test_helper::TestModel& get_test_model();
}  // namespace generated_tests::mobilenet_224_gender_basic_fixed

namespace generated_tests::mobilenet_quantized {
const test_helper::TestModel& get_test_model();
}  // namespace generated_tests::mobilenet_quantized

namespace android::hardware::neuralnetworks::V1_2::vts::functional {

using namespace test_helper;
using implementation::PreparedModelCallback;
using V1_0::ErrorStatus;
using V1_1::ExecutionPreference;

namespace float32_model {

constexpr auto get_test_model = generated_tests::mobilenet_224_gender_basic_fixed::get_test_model;

}  // namespace float32_model

namespace quant8_model {

constexpr auto get_test_model = generated_tests::mobilenet_quantized::get_test_model;

}  // namespace quant8_model

namespace {

enum class AccessMode { READ_WRITE, READ_ONLY, WRITE_ONLY };
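// How the modes are used below: READ_WRITE and WRITE_ONLY open the cache files with O_CREAT, so
// the files are created if they do not yet exist; READ_ONLY expects the file to already exist.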

// Creates cache handles based on provided file groups.
// The outer vector corresponds to handles and the inner vector is for fds held by each handle.
void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups,
                        const std::vector<AccessMode>& mode, hidl_vec<hidl_handle>* handles) {
    handles->resize(fileGroups.size());
    for (uint32_t i = 0; i < fileGroups.size(); i++) {
        std::vector<int> fds;
        for (const auto& file : fileGroups[i]) {
            int fd;
            if (mode[i] == AccessMode::READ_ONLY) {
                fd = open(file.c_str(), O_RDONLY);
            } else if (mode[i] == AccessMode::WRITE_ONLY) {
                fd = open(file.c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
            } else if (mode[i] == AccessMode::READ_WRITE) {
                fd = open(file.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
            } else {
                FAIL();
            }
            ASSERT_GE(fd, 0);
            fds.push_back(fd);
        }
        native_handle_t* cacheNativeHandle = native_handle_create(fds.size(), 0);
        ASSERT_NE(cacheNativeHandle, nullptr);
        std::copy(fds.begin(), fds.end(), &cacheNativeHandle->data[0]);
        (*handles)[i].setTo(cacheNativeHandle, /*shouldOwn=*/true);
    }
}
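// Note: each hidl_handle above takes ownership of its native_handle_t (shouldOwn = true), so the
// opened fds are closed automatically when the returned handles are destroyed.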

void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups, AccessMode mode,
                        hidl_vec<hidl_handle>* handles) {
    createCacheHandles(fileGroups, std::vector<AccessMode>(fileGroups.size(), mode), handles);
}

// Creates a chain of broadcast operations. The second operand is always the constant tensor [1].
// For simplicity, the activation scalar is shared. The second operand is not shared in the model,
// so that the driver maintains a non-trivial amount of constant data and the corresponding data
// locations in the cache.
//
//                --------- activation --------
//                ↓      ↓      ↓             ↓
// E.g. input -> ADD -> ADD -> ADD -> ... -> ADD -> output
//                ↑      ↑      ↑             ↑
//               [1]    [1]    [1]           [1]
//
// This function assumes the operation is either ADD or MUL.
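// Operand layout produced below: operand 0 is the shared activation scalar, operands 2*i+1 and
// 2*i+2 are the two inputs of operation i, and the last operand (index 2*len+1) is the model
// output.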
template <typename CppType, TestOperandType operandType>
TestModel createLargeTestModelImpl(TestOperationType op, uint32_t len) {
    EXPECT_TRUE(op == TestOperationType::ADD || op == TestOperationType::MUL);

    // Model operations and operands.
    std::vector<TestOperation> operations(len);
    std::vector<TestOperand> operands(len * 2 + 2);

    // The activation scalar, value = 0.
    operands[0] = {
            .type = TestOperandType::INT32,
            .dimensions = {},
            .numberOfConsumers = len,
            .scale = 0.0f,
            .zeroPoint = 0,
            .lifetime = TestOperandLifeTime::CONSTANT_COPY,
            .data = TestBuffer::createFromVector<int32_t>({0}),
    };

    // The buffer value of the constant second operand. The logical value is always 1.0f.
    CppType bufferValue;
    // The scale of the first and second operand.
    float scale1, scale2;
    if (operandType == TestOperandType::TENSOR_FLOAT32) {
        bufferValue = 1.0f;
        scale1 = 0.0f;
        scale2 = 0.0f;
    } else if (op == TestOperationType::ADD) {
        bufferValue = 1;
        scale1 = 1.0f;
        scale2 = 1.0f;
    } else {
        // To satisfy the constraint on quant8 MUL: input0.scale * input1.scale < output.scale,
        // set input1 to have scale = 0.5f and bufferValue = 2, i.e. 1.0f in floating point.
        bufferValue = 2;
        scale1 = 1.0f;
        scale2 = 0.5f;
    }

    for (uint32_t i = 0; i < len; i++) {
        const uint32_t firstInputIndex = i * 2 + 1;
        const uint32_t secondInputIndex = firstInputIndex + 1;
        const uint32_t outputIndex = secondInputIndex + 1;

        // The first operation input.
        operands[firstInputIndex] = {
                .type = operandType,
                .dimensions = {1},
                .numberOfConsumers = 1,
                .scale = scale1,
                .zeroPoint = 0,
                .lifetime = (i == 0 ? TestOperandLifeTime::MODEL_INPUT
                                    : TestOperandLifeTime::TEMPORARY_VARIABLE),
                .data = (i == 0 ? TestBuffer::createFromVector<CppType>({1}) : TestBuffer()),
        };

        // The second operation input, value = 1.
        operands[secondInputIndex] = {
                .type = operandType,
                .dimensions = {1},
                .numberOfConsumers = 1,
                .scale = scale2,
                .zeroPoint = 0,
                .lifetime = TestOperandLifeTime::CONSTANT_COPY,
                .data = TestBuffer::createFromVector<CppType>({bufferValue}),
        };

        // The operation. All operations share the same activation scalar.
        // The output operand is created as an input in the next iteration of the loop, in the case
        // of all but the last member of the chain; and after the loop as a model output, in the
        // case of the last member of the chain.
        operations[i] = {
                .type = op,
                .inputs = {firstInputIndex, secondInputIndex, /*activation scalar*/ 0},
                .outputs = {outputIndex},
        };
    }

    // For TestOperationType::ADD, output = 1 + 1 * len = len + 1
    // For TestOperationType::MUL, output = 1 * 1 ^ len = 1
    CppType outputResult = static_cast<CppType>(op == TestOperationType::ADD ? len + 1u : 1u);

    // The model output.
    operands.back() = {
            .type = operandType,
            .dimensions = {1},
            .numberOfConsumers = 0,
            .scale = scale1,
            .zeroPoint = 0,
            .lifetime = TestOperandLifeTime::MODEL_OUTPUT,
            .data = TestBuffer::createFromVector<CppType>({outputResult}),
    };

    return {
            .main = {.operands = std::move(operands),
                     .operations = std::move(operations),
                     .inputIndexes = {1},
                     .outputIndexes = {len * 2 + 1}},
            .isRelaxed = false,
    };
}

}  // namespace

// Tag for the compilation caching tests.
class CompilationCachingTestBase : public testing::Test {
  protected:
    CompilationCachingTestBase(sp<IDevice> device, OperandType type)
        : kDevice(std::move(device)), kOperandType(type) {}

    void SetUp() override {
        testing::Test::SetUp();
        ASSERT_NE(kDevice.get(), nullptr);
        const bool deviceIsResponsive = kDevice->ping().isOk();
        ASSERT_TRUE(deviceIsResponsive);

        // Create the cache directory. The cache directory and a temporary cache file are always
        // created to test the behavior of prepareModelFromCache, even when caching is not
        // supported.
        char cacheDirTemp[] = "/data/local/tmp/TestCompilationCachingXXXXXX";
        char* cacheDir = mkdtemp(cacheDirTemp);
        ASSERT_NE(cacheDir, nullptr);
        mCacheDir = cacheDir;
        mCacheDir.push_back('/');

        Return<void> ret = kDevice->getNumberOfCacheFilesNeeded(
                [this](ErrorStatus status, uint32_t numModelCache, uint32_t numDataCache) {
                    EXPECT_EQ(ErrorStatus::NONE, status);
                    mNumModelCache = numModelCache;
                    mNumDataCache = numDataCache;
                });
        EXPECT_TRUE(ret.isOk());
        mIsCachingSupported = mNumModelCache > 0 || mNumDataCache > 0;

        // Create empty cache files.
        mTmpCache = mCacheDir + "tmp";
        for (uint32_t i = 0; i < mNumModelCache; i++) {
            mModelCache.push_back({mCacheDir + "model" + std::to_string(i)});
        }
        for (uint32_t i = 0; i < mNumDataCache; i++) {
            mDataCache.push_back({mCacheDir + "data" + std::to_string(i)});
        }
        // Create sample handles with AccessMode::WRITE_ONLY so that createCacheHandles creates
        // the files.
        hidl_vec<hidl_handle> modelHandle, dataHandle, tmpHandle;
        createCacheHandles(mModelCache, AccessMode::WRITE_ONLY, &modelHandle);
        createCacheHandles(mDataCache, AccessMode::WRITE_ONLY, &dataHandle);
        createCacheHandles({{mTmpCache}}, AccessMode::WRITE_ONLY, &tmpHandle);

        if (!mIsCachingSupported) {
            LOG(INFO) << "NN VTS: Early termination of test because vendor service does not "
                         "support compilation caching.";
            std::cout << "[          ]   Early termination of test because vendor service does not "
                         "support compilation caching."
                      << std::endl;
        }
    }

    void TearDown() override {
        // If the test passes, remove the tmp directory. Otherwise, keep it for debugging purposes.
        if (!testing::Test::HasFailure()) {
            // Recursively remove the cache directory specified by mCacheDir.
            auto callback = [](const char* entry, const struct stat*, int, struct FTW*) {
                return remove(entry);
            };
            nftw(mCacheDir.c_str(), callback, 128, FTW_DEPTH | FTW_MOUNT | FTW_PHYS);
        }
        testing::Test::TearDown();
    }

    // Model and example creators. Depending on kOperandType, the following methods return either
    // the float32 model/examples or the quant8 variant.
    TestModel createTestModel() {
        if (kOperandType == OperandType::TENSOR_FLOAT32) {
            return float32_model::get_test_model();
        } else {
            return quant8_model::get_test_model();
        }
    }

    TestModel createLargeTestModel(OperationType op, uint32_t len) {
        if (kOperandType == OperandType::TENSOR_FLOAT32) {
            return createLargeTestModelImpl<float, TestOperandType::TENSOR_FLOAT32>(
                    static_cast<TestOperationType>(op), len);
        } else {
            return createLargeTestModelImpl<uint8_t, TestOperandType::TENSOR_QUANT8_ASYMM>(
                    static_cast<TestOperationType>(op), len);
        }
    }

    // See if the service can handle the model.
    bool isModelFullySupported(const Model& model) {
        bool fullySupportsModel = false;
        Return<void> supportedCall = kDevice->getSupportedOperations_1_2(
                model,
                [&fullySupportsModel, &model](ErrorStatus status, const hidl_vec<bool>& supported) {
                    ASSERT_EQ(ErrorStatus::NONE, status);
                    ASSERT_EQ(supported.size(), model.operations.size());
                    fullySupportsModel = std::all_of(supported.begin(), supported.end(),
                                                     [](bool valid) { return valid; });
                });
        EXPECT_TRUE(supportedCall.isOk());
        return fullySupportsModel;
    }

    void saveModelToCache(const Model& model, const hidl_vec<hidl_handle>& modelCache,
                          const hidl_vec<hidl_handle>& dataCache,
                          sp<IPreparedModel>* preparedModel = nullptr) {
        if (preparedModel != nullptr) *preparedModel = nullptr;

        // Launch prepare model.
        sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
        hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
        Return<ErrorStatus> prepareLaunchStatus =
                kDevice->prepareModel_1_2(model, ExecutionPreference::FAST_SINGLE_ANSWER,
                                          modelCache, dataCache, cacheToken, preparedModelCallback);
        ASSERT_TRUE(prepareLaunchStatus.isOk());
        ASSERT_EQ(static_cast<ErrorStatus>(prepareLaunchStatus), ErrorStatus::NONE);

        // Retrieve prepared model.
        preparedModelCallback->wait();
        ASSERT_EQ(preparedModelCallback->getStatus(), ErrorStatus::NONE);
        if (preparedModel != nullptr) {
            *preparedModel = IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
                                     .withDefault(nullptr);
        }
    }

    bool checkEarlyTermination(ErrorStatus status) {
        if (status == ErrorStatus::GENERAL_FAILURE) {
            LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
                         "save the prepared model that it does not support.";
            std::cout << "[          ]   Early termination of test because vendor service cannot "
                         "save the prepared model that it does not support."
                      << std::endl;
            return true;
        }
        return false;
    }

    bool checkEarlyTermination(const Model& model) {
        if (!isModelFullySupported(model)) {
            LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
                         "prepare model that it does not support.";
            std::cout << "[          ]   Early termination of test because vendor service cannot "
                         "prepare model that it does not support."
                      << std::endl;
            return true;
        }
        return false;
    }

    void prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
                               const hidl_vec<hidl_handle>& dataCache,
                               sp<IPreparedModel>* preparedModel, ErrorStatus* status) {
        // Launch prepare model from cache.
        sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
        hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
        Return<ErrorStatus> prepareLaunchStatus = kDevice->prepareModelFromCache(
                modelCache, dataCache, cacheToken, preparedModelCallback);
        ASSERT_TRUE(prepareLaunchStatus.isOk());
        if (static_cast<ErrorStatus>(prepareLaunchStatus) != ErrorStatus::NONE) {
            *preparedModel = nullptr;
            *status = static_cast<ErrorStatus>(prepareLaunchStatus);
            return;
        }

        // Retrieve prepared model.
        preparedModelCallback->wait();
        *status = preparedModelCallback->getStatus();
        *preparedModel = IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
                                 .withDefault(nullptr);
    }

    // Absolute path to the temporary cache directory.
    std::string mCacheDir;

    // Groups of file paths for model and data cache in the tmp cache directory, initialized with
    // outer_size = mNum{Model|Data}Cache, inner_size = 1. The outer vector corresponds to handles
    // and the inner vector is for fds held by each handle.
    std::vector<std::vector<std::string>> mModelCache;
    std::vector<std::vector<std::string>> mDataCache;

    // A separate temporary file path in the tmp cache directory.
    std::string mTmpCache;

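    // Compilation cache token passed to the driver; all zeros by default. Tests that need a
    // distinct token for a second model simply bump mToken[0].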
    uint8_t mToken[static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)] = {};
    // Number of model/data cache files reported by getNumberOfCacheFilesNeeded.
    uint32_t mNumModelCache;
    uint32_t mNumDataCache;
    // True if the driver reports at least one model or data cache file.
    bool mIsCachingSupported;

    const sp<IDevice> kDevice;
    // The primary data type of the testModel.
    const OperandType kOperandType;
};

using CompilationCachingTestParam = std::tuple<NamedDevice, OperandType>;

// A parameterized fixture of CompilationCachingTestBase. Every test will run twice, with the first
// pass running with float32 models and the second pass running with quant8 models.
class CompilationCachingTest : public CompilationCachingTestBase,
                               public testing::WithParamInterface<CompilationCachingTestParam> {
  protected:
    CompilationCachingTest()
        : CompilationCachingTestBase(getData(std::get<NamedDevice>(GetParam())),
                                     std::get<OperandType>(GetParam())) {}
};
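
// For reference, the pattern exercised by most tests below is:
//   1. createCacheHandles(mModelCache / mDataCache, AccessMode::READ_WRITE, ...)
//   2. saveModelToCache(model, modelCache, dataCache)
//   3. prepareModelFromCache(modelCache, dataCache, &preparedModel, &status)
//   4. EvaluatePreparedModel(preparedModel, testModel, /*testDynamicOutputShape=*/false)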

TEST_P(CompilationCachingTest, CacheSavingAndRetrieval) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;
    sp<IPreparedModel> preparedModel = nullptr;

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(model, modelCache, dataCache);
    }

    // Retrieve preparedModel from cache.
    {
        preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (!mIsCachingSupported) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
            ASSERT_EQ(preparedModel, nullptr);
            return;
        } else if (checkEarlyTermination(status)) {
            ASSERT_EQ(preparedModel, nullptr);
            return;
        } else {
            ASSERT_EQ(status, ErrorStatus::NONE);
            ASSERT_NE(preparedModel, nullptr);
        }
    }

    // Execute and verify results.
    EvaluatePreparedModel(preparedModel, testModel,
                          /*testDynamicOutputShape=*/false);
}

TEST_P(CompilationCachingTest, CacheSavingAndRetrievalNonZeroOffset) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;
    sp<IPreparedModel> preparedModel = nullptr;

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        uint8_t sampleBytes[] = {0, 0};
        // Write two sample bytes to each cache file.
        // The driver should be able to handle a non-empty cache and a non-zero fd offset.
        for (uint32_t i = 0; i < modelCache.size(); i++) {
            ASSERT_EQ(write(modelCache[i].getNativeHandle()->data[0], &sampleBytes,
                            sizeof(sampleBytes)),
                      sizeof(sampleBytes));
        }
        for (uint32_t i = 0; i < dataCache.size(); i++) {
            ASSERT_EQ(
                    write(dataCache[i].getNativeHandle()->data[0], &sampleBytes, sizeof(sampleBytes)),
                    sizeof(sampleBytes));
        }
        saveModelToCache(model, modelCache, dataCache);
    }

    // Retrieve preparedModel from cache.
    {
        preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        uint8_t sampleByte = 0;
        // Advance the offset of each handle by one byte.
        // The driver should be able to handle a non-zero fd offset.
        for (uint32_t i = 0; i < modelCache.size(); i++) {
            ASSERT_GE(read(modelCache[i].getNativeHandle()->data[0], &sampleByte, 1), 0);
        }
        for (uint32_t i = 0; i < dataCache.size(); i++) {
            ASSERT_GE(read(dataCache[i].getNativeHandle()->data[0], &sampleByte, 1), 0);
        }
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (!mIsCachingSupported) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
            ASSERT_EQ(preparedModel, nullptr);
            return;
        } else if (checkEarlyTermination(status)) {
            ASSERT_EQ(preparedModel, nullptr);
            return;
        } else {
            ASSERT_EQ(status, ErrorStatus::NONE);
            ASSERT_NE(preparedModel, nullptr);
        }
    }

    // Execute and verify results.
    EvaluatePreparedModel(preparedModel, testModel,
                          /*testDynamicOutputShape=*/false);
}

TEST_P(CompilationCachingTest, SaveToCacheInvalidNumCache) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Test with number of model cache files greater than mNumModelCache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an additional cache file for model cache.
        mModelCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.pop_back();
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of model cache files smaller than mNumModelCache.
    if (mModelCache.size() > 0) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pop out the last cache file.
        auto tmp = mModelCache.back();
        mModelCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.push_back(tmp);
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files greater than mNumDataCache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an additional cache file for data cache.
        mDataCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.pop_back();
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files smaller than mNumDataCache.
    if (mDataCache.size() > 0) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pop out the last cache file.
        auto tmp = mDataCache.back();
        mDataCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.push_back(tmp);
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}

TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumCache) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(model, modelCache, dataCache);
    }

    // Test with number of model cache files greater than mNumModelCache.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        mModelCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.pop_back();
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of model cache files smaller than mNumModelCache.
    if (mModelCache.size() > 0) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        auto tmp = mModelCache.back();
        mModelCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.push_back(tmp);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files greater than mNumDataCache.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        mDataCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.pop_back();
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files smaller than mNumDataCache.
    if (mDataCache.size() > 0) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        auto tmp = mDataCache.back();
        mDataCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.push_back(tmp);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}

TEST_P(CompilationCachingTest, SaveToCacheInvalidNumFd) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Go through each handle in model cache, test with NumFd greater than 1.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an invalid number of fds for handle i.
        mModelCache[i].push_back(mTmpCache);
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache[i].pop_back();
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in model cache, test with NumFd equal to 0.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an invalid number of fds for handle i.
        auto tmp = mModelCache[i].back();
        mModelCache[i].pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache[i].push_back(tmp);
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with NumFd greater than 1.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an invalid number of fds for handle i.
        mDataCache[i].push_back(mTmpCache);
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache[i].pop_back();
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with NumFd equal to 0.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an invalid number of fds for handle i.
        auto tmp = mDataCache[i].back();
        mDataCache[i].pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache[i].push_back(tmp);
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}

TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumFd) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(model, modelCache, dataCache);
    }

    // Go through each handle in model cache, test with NumFd greater than 1.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        mModelCache[i].push_back(mTmpCache);
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache[i].pop_back();
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in model cache, test with NumFd equal to 0.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        auto tmp = mModelCache[i].back();
        mModelCache[i].pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache[i].push_back(tmp);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with NumFd greater than 1.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        mDataCache[i].push_back(mTmpCache);
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache[i].pop_back();
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with NumFd equal to 0.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        auto tmp = mDataCache[i].back();
        mDataCache[i].pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache[i].push_back(tmp);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}

TEST_P(CompilationCachingTest, SaveToCacheInvalidAccessMode) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;
    std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
    std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);

    // Go through each handle in model cache, test with invalid access mode.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        modelCacheMode[i] = AccessMode::READ_ONLY;
        createCacheHandles(mModelCache, modelCacheMode, &modelCache);
        createCacheHandles(mDataCache, dataCacheMode, &dataCache);
        modelCacheMode[i] = AccessMode::READ_WRITE;
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with invalid access mode.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        dataCacheMode[i] = AccessMode::READ_ONLY;
        createCacheHandles(mModelCache, modelCacheMode, &modelCache);
        createCacheHandles(mDataCache, dataCacheMode, &dataCache);
        dataCacheMode[i] = AccessMode::READ_WRITE;
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}

TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidAccessMode) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;
    std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
    std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(model, modelCache, dataCache);
    }

    // Go through each handle in model cache, test with invalid access mode.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        modelCacheMode[i] = AccessMode::WRITE_ONLY;
        createCacheHandles(mModelCache, modelCacheMode, &modelCache);
        createCacheHandles(mDataCache, dataCacheMode, &dataCache);
        modelCacheMode[i] = AccessMode::READ_WRITE;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with invalid access mode.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        dataCacheMode[i] = AccessMode::WRITE_ONLY;
        createCacheHandles(mModelCache, modelCacheMode, &modelCache);
        createCacheHandles(mDataCache, dataCacheMode, &dataCache);
        dataCacheMode[i] = AccessMode::READ_WRITE;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }
}

// Copy file contents between file groups.
// The outer vector corresponds to handles and the inner vector is for fds held by each handle.
// The outer vector sizes must match and the inner vectors must have size = 1.
static void copyCacheFiles(const std::vector<std::vector<std::string>>& from,
                           const std::vector<std::vector<std::string>>& to) {
    constexpr size_t kBufferSize = 1000000;
    uint8_t buffer[kBufferSize];

    ASSERT_EQ(from.size(), to.size());
    for (uint32_t i = 0; i < from.size(); i++) {
        ASSERT_EQ(from[i].size(), 1u);
        ASSERT_EQ(to[i].size(), 1u);
        int fromFd = open(from[i][0].c_str(), O_RDONLY);
        int toFd = open(to[i][0].c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
        ASSERT_GE(fromFd, 0);
        ASSERT_GE(toFd, 0);

        ssize_t readBytes;
        while ((readBytes = read(fromFd, &buffer, kBufferSize)) > 0) {
            ASSERT_EQ(write(toFd, &buffer, readBytes), readBytes);
        }
        ASSERT_GE(readBytes, 0);

        close(fromFd);
        close(toFd);
    }
}

// Number of operations in the large test model.
constexpr uint32_t kLargeModelSize = 100;
constexpr uint32_t kNumIterationsTOCTOU = 100;
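
// The TOCTOU (time-of-check to time-of-use) tests below race an external writer that overwrites
// the cache files against the driver's own cache accesses. The driver may fail the preparation,
// but it must not crash and must not produce incorrect results.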
1028 
TEST_P(CompilationCachingTest,SaveToCache_TOCTOU)1029 TEST_P(CompilationCachingTest, SaveToCache_TOCTOU) {
1030     if (!mIsCachingSupported) return;
1031 
1032     // Create test models and check if fully supported by the service.
1033     const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
1034     const Model modelMul = createModel(testModelMul);
1035     if (checkEarlyTermination(modelMul)) return;
1036     const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
1037     const Model modelAdd = createModel(testModelAdd);
1038     if (checkEarlyTermination(modelAdd)) return;
1039 
1040     // Save the modelMul compilation to cache.
1041     auto modelCacheMul = mModelCache;
1042     for (auto& cache : modelCacheMul) {
1043         cache[0].append("_mul");
1044     }
1045     {
1046         hidl_vec<hidl_handle> modelCache, dataCache;
1047         createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
1048         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1049         saveModelToCache(modelMul, modelCache, dataCache);
1050     }
1051 
1052     // Use a different token for modelAdd.
1053     mToken[0]++;
1054 
1055     // This test is probabilistic, so we run it multiple times.
1056     for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
1057         // Save the modelAdd compilation to cache.
1058         {
1059             hidl_vec<hidl_handle> modelCache, dataCache;
1060             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1061             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1062 
1063             // Spawn a thread to copy the cache content concurrently while saving to cache.
1064             std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
1065             saveModelToCache(modelAdd, modelCache, dataCache);
1066             thread.join();
1067         }
1068 
1069         // Retrieve preparedModel from cache.
1070         {
1071             sp<IPreparedModel> preparedModel = nullptr;
1072             ErrorStatus status;
1073             hidl_vec<hidl_handle> modelCache, dataCache;
1074             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1075             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1076             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1077 
1078             // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
1079             // the prepared model must be executed with the correct result and not crash.
1080             if (status != ErrorStatus::NONE) {
1081                 ASSERT_EQ(preparedModel, nullptr);
1082             } else {
1083                 ASSERT_NE(preparedModel, nullptr);
1084                 EvaluatePreparedModel(preparedModel, testModelAdd,
1085                                       /*testDynamicOutputShape=*/false);
1086             }
1087         }
1088     }
1089 }
1090 
TEST_P(CompilationCachingTest,PrepareFromCache_TOCTOU)1091 TEST_P(CompilationCachingTest, PrepareFromCache_TOCTOU) {
1092     if (!mIsCachingSupported) return;
1093 
1094     // Create test models and check if fully supported by the service.
1095     const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
1096     const Model modelMul = createModel(testModelMul);
1097     if (checkEarlyTermination(modelMul)) return;
1098     const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
1099     const Model modelAdd = createModel(testModelAdd);
1100     if (checkEarlyTermination(modelAdd)) return;
1101 
1102     // Save the modelMul compilation to cache.
1103     auto modelCacheMul = mModelCache;
1104     for (auto& cache : modelCacheMul) {
1105         cache[0].append("_mul");
1106     }
1107     {
1108         hidl_vec<hidl_handle> modelCache, dataCache;
1109         createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
1110         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1111         saveModelToCache(modelMul, modelCache, dataCache);
1112     }
1113 
1114     // Use a different token for modelAdd.
1115     mToken[0]++;
1116 
1117     // This test is probabilistic, so we run it multiple times.
1118     for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
1119         // Save the modelAdd compilation to cache.
1120         {
1121             hidl_vec<hidl_handle> modelCache, dataCache;
1122             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1123             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1124             saveModelToCache(modelAdd, modelCache, dataCache);
1125         }
1126 
1127         // Retrieve preparedModel from cache.
1128         {
1129             sp<IPreparedModel> preparedModel = nullptr;
1130             ErrorStatus status;
1131             hidl_vec<hidl_handle> modelCache, dataCache;
1132             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1133             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1134 
1135             // Spawn a thread to copy the cache content concurrently while preparing from cache.
1136             std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
1137             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1138             thread.join();
1139 
1140             // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
1141             // the prepared model must be executed with the correct result and not crash.
1142             if (status != ErrorStatus::NONE) {
1143                 ASSERT_EQ(preparedModel, nullptr);
1144             } else {
1145                 ASSERT_NE(preparedModel, nullptr);
1146                 EvaluatePreparedModel(preparedModel, testModelAdd,
1147                                       /*testDynamicOutputShape=*/false);
1148             }
1149         }
1150     }
1151 }
1152 
TEST_P(CompilationCachingTest,ReplaceSecuritySensitiveCache)1153 TEST_P(CompilationCachingTest, ReplaceSecuritySensitiveCache) {
    if (!mIsCachingSupported) return;

    // Create test models and check if fully supported by the service.
    const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    const Model modelMul = createModel(testModelMul);
    if (checkEarlyTermination(modelMul)) return;
    const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    const Model modelAdd = createModel(testModelAdd);
    if (checkEarlyTermination(modelAdd)) return;

    // Save the modelMul compilation to cache.
    auto modelCacheMul = mModelCache;
    for (auto& cache : modelCacheMul) {
        cache[0].append("_mul");
    }
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelMul, modelCache, dataCache);
    }

    // Use a different token for modelAdd.
    mToken[0]++;

    // Save the modelAdd compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelAdd, modelCache, dataCache);
    }

    // Replace the model cache of modelAdd with modelMul.
    copyCacheFiles(modelCacheMul, mModelCache);

    // Retrieve the preparedModel from cache, expect failure.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }
}

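// Instantiate the caching tests for every available device, once with a float32 model and once
// with a quant8 model.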
static const auto kNamedDeviceChoices = testing::ValuesIn(getNamedDevices());
static const auto kOperandTypeChoices =
        testing::Values(OperandType::TENSOR_FLOAT32, OperandType::TENSOR_QUANT8_ASYMM);

std::string printCompilationCachingTest(
        const testing::TestParamInfo<CompilationCachingTestParam>& info) {
    const auto& [namedDevice, operandType] = info.param;
    const std::string type = (operandType == OperandType::TENSOR_FLOAT32 ? "float32" : "quant8");
    return gtestCompliantName(getName(namedDevice) + "_" + type);
}

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CompilationCachingTest);
INSTANTIATE_TEST_SUITE_P(TestCompilationCaching, CompilationCachingTest,
                         testing::Combine(kNamedDeviceChoices, kOperandTypeChoices),
                         printCompilationCachingTest);

using CompilationCachingSecurityTestParam = std::tuple<NamedDevice, OperandType, uint32_t>;

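// Each security test corrupts the cache entries or the token between saveModelToCache and
// prepareModelFromCache. The extra uint32_t parameter is an RNG seed, so every test instance
// applies a deterministic, reproducible corruption pattern.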
class CompilationCachingSecurityTest
    : public CompilationCachingTestBase,
      public testing::WithParamInterface<CompilationCachingSecurityTestParam> {
  protected:
    CompilationCachingSecurityTest()
        : CompilationCachingTestBase(getData(std::get<NamedDevice>(GetParam())),
                                     std::get<OperandType>(GetParam())) {}

    void SetUp() {
        CompilationCachingTestBase::SetUp();
        generator.seed(kSeed);
    }

    // Get a random integer within a closed range [lower, upper].
    template <typename T>
    T getRandomInt(T lower, T upper) {
        std::uniform_int_distribution<T> dis(lower, upper);
        return dis(generator);
    }

    // Randomly flip a single bit of the cache entry. Sets *skip to true if the entry is empty
    // and there is nothing to corrupt.
    void flipOneBitOfCache(const std::string& filename, bool* skip) {
        FILE* pFile = fopen(filename.c_str(), "r+");
        ASSERT_NE(pFile, nullptr);
        // Determine the file size by seeking to the end.
        ASSERT_EQ(fseek(pFile, 0, SEEK_END), 0);
        long int fileSize = ftell(pFile);
        if (fileSize == 0) {
            fclose(pFile);
            *skip = true;
            return;
        }
        // Pick a random byte, read it, then overwrite it in place with one bit flipped.
        ASSERT_EQ(fseek(pFile, getRandomInt(0l, fileSize - 1), SEEK_SET), 0);
        int readByte = fgetc(pFile);
        ASSERT_NE(readByte, EOF);
        ASSERT_EQ(fseek(pFile, -1, SEEK_CUR), 0);
        ASSERT_NE(fputc(static_cast<uint8_t>(readByte) ^ (1U << getRandomInt(0, 7)), pFile), EOF);
        fclose(pFile);
        *skip = false;
    }

    // Randomly append 1-256 bytes to the cache entry.
    void appendBytesToCache(const std::string& filename, bool* skip) {
        FILE* pFile = fopen(filename.c_str(), "a");
        ASSERT_NE(pFile, nullptr);
        uint32_t appendLength = getRandomInt(1, 256);
        for (uint32_t i = 0; i < appendLength; i++) {
            ASSERT_NE(fputc(getRandomInt<uint16_t>(0, 255), pFile), EOF);
        }
        fclose(pFile);
        *skip = false;
    }

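    // GENERAL_FAILURE: the driver must detect the corruption and refuse to prepare from cache.
    // NOT_CRASH: the driver may or may not detect the corruption; it only must not crash.
    // The tests below require detection for the model cache and the token, but only NOT_CRASH
    // for the data cache.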
    enum class ExpectedResult { GENERAL_FAILURE, NOT_CRASH };

    // Tests whether the driver behaves as expected when given a corrupted cache or token.
    // The modifier is invoked after the model is saved to cache but before it is prepared
    // from cache; it reports through its out-parameter "skip" whether the test should be
    // skipped (e.g., because the targeted cache entry is empty).
    void testCorruptedCache(ExpectedResult expected, std::function<void(bool*)> modifier) {
        const TestModel& testModel = createTestModel();
        const Model model = createModel(testModel);
        if (checkEarlyTermination(model)) return;

        // Save the compilation to cache.
        {
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
            saveModelToCache(model, modelCache, dataCache);
        }

        bool skip = false;
        modifier(&skip);
        if (skip) return;

        // Retrieve preparedModel from cache.
        {
            sp<IPreparedModel> preparedModel = nullptr;
            ErrorStatus status;
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
            prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);

            switch (expected) {
                case ExpectedResult::GENERAL_FAILURE:
                    ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
                    ASSERT_EQ(preparedModel, nullptr);
                    break;
                case ExpectedResult::NOT_CRASH:
                    ASSERT_EQ(preparedModel == nullptr, status != ErrorStatus::NONE);
                    break;
                default:
                    FAIL();
            }
        }
    }

    const uint32_t kSeed = std::get<uint32_t>(GetParam());
    std::mt19937 generator;
};

TEST_P(CompilationCachingSecurityTest, CorruptedModelCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE,
                           [this, i](bool* skip) { flipOneBitOfCache(mModelCache[i][0], skip); });
    }
}

TEST_P(CompilationCachingSecurityTest, WrongLengthModelCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE,
                           [this, i](bool* skip) { appendBytesToCache(mModelCache[i][0], skip); });
    }
}

TEST_P(CompilationCachingSecurityTest, CorruptedDataCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        testCorruptedCache(ExpectedResult::NOT_CRASH,
                           [this, i](bool* skip) { flipOneBitOfCache(mDataCache[i][0], skip); });
    }
}

TEST_P(CompilationCachingSecurityTest, WrongLengthDataCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        testCorruptedCache(ExpectedResult::NOT_CRASH,
                           [this, i](bool* skip) { appendBytesToCache(mDataCache[i][0], skip); });
    }
}

TEST_P(CompilationCachingSecurityTest, WrongToken) {
    if (!mIsCachingSupported) return;
    testCorruptedCache(ExpectedResult::GENERAL_FAILURE, [this](bool* skip) {
        // Randomly flip a single bit of mToken.
        uint32_t ind =
                getRandomInt(0u, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN) - 1);
        mToken[ind] ^= (1U << getRandomInt(0, 7));
        *skip = false;
    });
}

std::string printCompilationCachingSecurityTest(
        const testing::TestParamInfo<CompilationCachingSecurityTestParam>& info) {
    const auto& [namedDevice, operandType, seed] = info.param;
    const std::string type = (operandType == OperandType::TENSOR_FLOAT32 ? "float32" : "quant8");
    return gtestCompliantName(getName(namedDevice) + "_" + type + "_" + std::to_string(seed));
}

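// Instantiate the security tests with ten RNG seeds (0 through 9) for each device and operand
// type combination.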
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CompilationCachingSecurityTest);
INSTANTIATE_TEST_SUITE_P(TestCompilationCaching, CompilationCachingSecurityTest,
                         testing::Combine(kNamedDeviceChoices, kOperandTypeChoices,
                                          testing::Range(0U, 10U)),
                         printCompilationCachingSecurityTest);

}  // namespace android::hardware::neuralnetworks::V1_2::vts::functional