1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <android-base/scopeguard.h>
18 #include <gtest/gtest.h>
19
20 #include <cstdlib>
21 #include <filesystem>
22 #include <numeric>
23 #include <string>
24 #include <string_view>
25 #include <tuple>
26 #include <vector>
27
28 #include "HalInterfaces.h"
29 #include "Manager.h"
30 #include "SampleDriver.h"
31 #include "TestNeuralNetworksWrapper.h"
32
using namespace android::nn;
using namespace hal;
using Result = test_wrapper::Result;
using Type = test_wrapper::Type;
// Timing value passed to execution callbacks that fail; UINT64_MAX marks both measurements
// as unavailable.
const Timing kBadTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
template <typename T>
using MQDescriptorSync = ::android::hardware::MQDescriptorSync<T>;
40
41 namespace android::hardware::neuralnetworks::V1_0 {
42
operator <<(::std::ostream & os,ErrorStatus errorStatus)43 ::std::ostream& operator<<(::std::ostream& os, ErrorStatus errorStatus) {
44 return os << toString(errorStatus);
45 }
46
47 } // namespace android::hardware::neuralnetworks::V1_0
48
49 namespace {
50
// Records which preparation path the test driver observed: prepareModel_1_3 never called,
// called without cache file handles, or called with cache file handles.
enum class HasCalledPrepareModel { NO, WITHOUT_CACHING, WITH_CACHING };
52
53 // Print HasCalledPrepareModel enum for better GTEST failure messages
operator <<(std::ostream & os,HasCalledPrepareModel hasCalledPrepareModel)54 std::ostream& operator<<(std::ostream& os, HasCalledPrepareModel hasCalledPrepareModel) {
55 switch (hasCalledPrepareModel) {
56 case HasCalledPrepareModel::NO:
57 return os << "NO";
58 case HasCalledPrepareModel::WITHOUT_CACHING:
59 return os << "WITHOUT_CACHING";
60 case HasCalledPrepareModel::WITH_CACHING:
61 return os << "WITH_CACHING";
62 }
63 CHECK(false) << "HasCalledPrepareModel print called with invalid code "
64 << static_cast<int>(hasCalledPrepareModel);
65 return os;
66 }
67
68 // Whether the driver is expected to be registered because it can pass initialization.
canDeviceBeRegistered(ErrorStatus error,uint32_t numModelCache,uint32_t numDataCache)69 bool canDeviceBeRegistered(ErrorStatus error, uint32_t numModelCache, uint32_t numDataCache) {
70 constexpr uint32_t maxNumCacheFiles =
71 static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES);
72 return error == ErrorStatus::NONE && numModelCache <= maxNumCacheFiles &&
73 numDataCache <= maxNumCacheFiles;
74 }
75
// Whether the driver supports caching based on the returns from getNumberOfCacheFilesNeeded.
bool isCachingSupported(uint32_t numModelCache, uint32_t numDataCache) {
    // A driver that asks for zero files of both kinds opts out of compilation caching.
    const bool needsNoCacheFiles = (numModelCache == 0 && numDataCache == 0);
    return !needsNoCacheFiles;
}
80
81 // This is an IDevice for testing purposes which overrides several methods from sample driver:
82 // - supports all the operations and is faster than cpu fallback.
83 // - overrides getNumberOfCacheFilesNeeded to report according to given parameters.
84 // - overrides prepareModelFromCache_1_3 to return error status according to
85 // mErrorStatusPrepareFromCache.
86 // - produces CachingPreparedModel on prepareModel and prepareModelFromCache_1_3.
87 //
88 // The cache entry is written by prepareModel_1_3 and is checked later by
89 // CachingDriver::prepareModelFromCache_1_3.
90 //
91 // The CachingDriver has 2 flags mHasCalledPrepareModelFromCache and mHasCalledPrepareModel
92 // to check if the correct methods are invoked by the runtime.
93 class CachingDriver : public sample_driver::SampleDriver {
94 private:
95 static constexpr size_t kCacheSize = 256;
96
97 class CachingPreparedModel : public IPreparedModel {
98 public:
99 CachingPreparedModel() = default;
100
execute(const V1_0::Request &,const sp<V1_0::IExecutionCallback> &)101 Return<V1_0::ErrorStatus> execute(const V1_0::Request&,
102 const sp<V1_0::IExecutionCallback>&) override {
103 return V1_0::ErrorStatus::DEVICE_UNAVAILABLE;
104 }
execute_1_2(const V1_0::Request &,MeasureTiming,const sp<V1_2::IExecutionCallback> &)105 Return<V1_0::ErrorStatus> execute_1_2(const V1_0::Request&, MeasureTiming,
106 const sp<V1_2::IExecutionCallback>&) override {
107 return V1_0::ErrorStatus::DEVICE_UNAVAILABLE;
108 }
execute_1_3(const V1_3::Request &,MeasureTiming,const OptionalTimePoint &,const OptionalTimeoutDuration &,const sp<V1_3::IExecutionCallback> &)109 Return<V1_3::ErrorStatus> execute_1_3(const V1_3::Request&, MeasureTiming,
110 const OptionalTimePoint&,
111 const OptionalTimeoutDuration&,
112 const sp<V1_3::IExecutionCallback>&) override {
113 return V1_3::ErrorStatus::DEVICE_UNAVAILABLE;
114 }
executeSynchronously(const V1_0::Request &,MeasureTiming,executeSynchronously_cb cb)115 Return<void> executeSynchronously(const V1_0::Request&, MeasureTiming,
116 executeSynchronously_cb cb) override {
117 cb(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, {}, kBadTiming);
118 return Void();
119 }
executeSynchronously_1_3(const V1_3::Request &,MeasureTiming,const OptionalTimePoint &,const OptionalTimeoutDuration &,executeSynchronously_1_3_cb cb)120 Return<void> executeSynchronously_1_3(const V1_3::Request&, MeasureTiming,
121 const OptionalTimePoint&,
122 const OptionalTimeoutDuration&,
123 executeSynchronously_1_3_cb cb) override {
124 cb(V1_3::ErrorStatus::DEVICE_UNAVAILABLE, {}, kBadTiming);
125 return Void();
126 }
configureExecutionBurst(const sp<V1_2::IBurstCallback> &,const MQDescriptorSync<V1_2::FmqRequestDatum> &,const MQDescriptorSync<V1_2::FmqResultDatum> &,configureExecutionBurst_cb cb)127 Return<void> configureExecutionBurst(const sp<V1_2::IBurstCallback>&,
128 const MQDescriptorSync<V1_2::FmqRequestDatum>&,
129 const MQDescriptorSync<V1_2::FmqResultDatum>&,
130 configureExecutionBurst_cb cb) override {
131 cb(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, nullptr);
132 return Void();
133 }
executeFenced(const hal::Request &,const hidl_vec<hidl_handle> &,MeasureTiming,const OptionalTimePoint &,const OptionalTimeoutDuration &,const OptionalTimeoutDuration &,executeFenced_cb cb)134 Return<void> executeFenced(const hal::Request&, const hidl_vec<hidl_handle>&, MeasureTiming,
135 const OptionalTimePoint&, const OptionalTimeoutDuration&,
136 const OptionalTimeoutDuration&, executeFenced_cb cb) {
137 cb(ErrorStatus::DEVICE_UNAVAILABLE, hidl_handle(nullptr), nullptr);
138 return Void();
139 }
140 };
141
142 public:
CachingDriver(std::string_view name,ErrorStatus errorStatusGetNumCacheFiles,uint32_t numModelCache,uint32_t numDataCache,ErrorStatus errorStatusPrepareFromCache)143 CachingDriver(std::string_view name, ErrorStatus errorStatusGetNumCacheFiles,
144 uint32_t numModelCache, uint32_t numDataCache,
145 ErrorStatus errorStatusPrepareFromCache)
146 : SampleDriver(name.data()),
147 mErrorStatusGetNumCacheFiles(errorStatusGetNumCacheFiles),
148 mNumModelCache(numModelCache),
149 mNumDataCache(numDataCache),
150 mErrorStatusPrepareFromCache(errorStatusPrepareFromCache) {
151 mModelCacheData.resize(kCacheSize);
152 std::iota(mModelCacheData.begin(), mModelCacheData.end(), 0);
153 mDataCacheData.resize(kCacheSize);
154 std::iota(mDataCacheData.begin(), mDataCacheData.end(), 1);
155 }
~CachingDriver()156 ~CachingDriver() override {}
157
158 // Reports faster than cpu.
getCapabilities_1_3(getCapabilities_1_3_cb cb)159 Return<void> getCapabilities_1_3(getCapabilities_1_3_cb cb) override {
160 android::nn::initVLogMask();
161 const PerformanceInfo kPerf = {.execTime = 0.1, .powerUsage = 0.1};
162 Capabilities capabilities = {
163 .relaxedFloat32toFloat16PerformanceScalar = kPerf,
164 .relaxedFloat32toFloat16PerformanceTensor = kPerf,
165 .operandPerformance = nonExtensionOperandPerformance<HalVersion::V1_3>(kPerf),
166 .ifPerformance = kPerf,
167 .whilePerformance = kPerf};
168 cb(V1_3::ErrorStatus::NONE, capabilities);
169 return Void();
170 }
171
172 // Reports supporting all operations.
getSupportedOperations_1_3(const Model & model,getSupportedOperations_1_3_cb cb)173 Return<void> getSupportedOperations_1_3(const Model& model,
174 getSupportedOperations_1_3_cb cb) override {
175 std::vector<bool> supported(model.main.operations.size(), true);
176 cb(V1_3::ErrorStatus::NONE, supported);
177 return Void();
178 }
179
180 // Reports according to mGetNumCacheFiles.
getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb)181 Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) override {
182 cb(convertToV1_0(mErrorStatusGetNumCacheFiles), mNumModelCache, mNumDataCache);
183 return Void();
184 }
185
186 // Generates CachingPreparedModel.
187 // Writes the cache entry per mCacheXData and sets mHasCalledPrepareModel.
prepareModel_1_3(const Model &,ExecutionPreference,Priority,const OptionalTimePoint &,const hidl_vec<hidl_handle> & modelCacheHandle,const hidl_vec<hidl_handle> & dataCacheHandle,const CacheToken &,const sp<V1_3::IPreparedModelCallback> & cb)188 Return<V1_3::ErrorStatus> prepareModel_1_3(
189 const Model&, ExecutionPreference, Priority, const OptionalTimePoint&,
190 const hidl_vec<hidl_handle>& modelCacheHandle,
191 const hidl_vec<hidl_handle>& dataCacheHandle, const CacheToken&,
192 const sp<V1_3::IPreparedModelCallback>& cb) override {
193 checkNumberOfCacheHandles(modelCacheHandle.size(), dataCacheHandle.size());
194 if (modelCacheHandle.size() != 0 || dataCacheHandle.size() != 0) {
195 writeToCache(modelCacheHandle, mModelCacheData);
196 writeToCache(dataCacheHandle, mDataCacheData);
197 mHasCalledPrepareModel = HasCalledPrepareModel::WITH_CACHING;
198 } else {
199 mHasCalledPrepareModel = HasCalledPrepareModel::WITHOUT_CACHING;
200 }
201 cb->notify_1_3(V1_3::ErrorStatus::NONE, new CachingPreparedModel());
202 return V1_3::ErrorStatus::NONE;
203 }
204
205 // Checks if the cache entry is correct, notifies error status according to
206 // mErrorStatusPrepareFromCache, sets mHasCalledPrepareModelFromCache.
prepareModelFromCache_1_3(const OptionalTimePoint &,const hidl_vec<hidl_handle> & modelCacheHandle,const hidl_vec<hidl_handle> & dataCacheHandle,const CacheToken &,const sp<V1_3::IPreparedModelCallback> & callback)207 Return<V1_3::ErrorStatus> prepareModelFromCache_1_3(
208 const OptionalTimePoint&, const hidl_vec<hidl_handle>& modelCacheHandle,
209 const hidl_vec<hidl_handle>& dataCacheHandle, const CacheToken&,
210 const sp<V1_3::IPreparedModelCallback>& callback) override {
211 readFromCache(modelCacheHandle, mModelCacheData);
212 readFromCache(dataCacheHandle, mDataCacheData);
213 mHasCalledPrepareModelFromCache = true;
214 if (mErrorStatusPrepareFromCache == V1_3::ErrorStatus::NONE) {
215 callback->notify_1_3(mErrorStatusPrepareFromCache, new CachingPreparedModel());
216 } else {
217 callback->notify_1_3(mErrorStatusPrepareFromCache, nullptr);
218 }
219 return V1_3::ErrorStatus::NONE;
220 };
221
hasCalledPrepareModelFromCache() const222 bool hasCalledPrepareModelFromCache() const { return mHasCalledPrepareModelFromCache; }
hasCalledPrepareModel() const223 HasCalledPrepareModel hasCalledPrepareModel() const { return mHasCalledPrepareModel; }
224
225 private:
226 // Checks the number of cache files passed to the driver from runtime.
checkNumberOfCacheHandles(size_t modelCache,size_t dataCache)227 void checkNumberOfCacheHandles(size_t modelCache, size_t dataCache) {
228 if (isCachingSupported(mNumModelCache, mNumDataCache)) {
229 if (modelCache != 0 || dataCache != 0) {
230 ASSERT_EQ(modelCache, mNumModelCache);
231 ASSERT_EQ(dataCache, mNumDataCache);
232 }
233 } else {
234 ASSERT_EQ(modelCache, 0ul);
235 ASSERT_EQ(dataCache, 0ul);
236 }
237 }
238
writeToCache(const hidl_vec<hidl_handle> & handles,const std::vector<uint8_t> & cache)239 void writeToCache(const hidl_vec<hidl_handle>& handles, const std::vector<uint8_t>& cache) {
240 for (uint32_t i = 0; i < handles.size(); ++i) {
241 ASSERT_EQ(handles[i]->numFds, 1);
242 EXPECT_EQ(write(handles[i]->data[0], cache.data(), kCacheSize),
243 static_cast<ssize_t>(kCacheSize));
244 }
245 }
246
readFromCache(const hidl_vec<hidl_handle> & handles,const std::vector<uint8_t> & expected)247 void readFromCache(const hidl_vec<hidl_handle>& handles, const std::vector<uint8_t>& expected) {
248 for (uint32_t i = 0; i < handles.size(); ++i) {
249 ASSERT_EQ(handles[i]->numFds, 1);
250 std::vector<uint8_t> actual(kCacheSize);
251 EXPECT_EQ(read(handles[i]->data[0], actual.data(), kCacheSize),
252 static_cast<ssize_t>(kCacheSize));
253 EXPECT_EQ(actual, expected);
254 }
255 }
256
257 std::vector<uint8_t> mModelCacheData;
258 std::vector<uint8_t> mDataCacheData;
259
260 const ErrorStatus mErrorStatusGetNumCacheFiles;
261 const uint32_t mNumModelCache;
262 const uint32_t mNumDataCache;
263 const ErrorStatus mErrorStatusPrepareFromCache;
264
265 bool mHasCalledPrepareModelFromCache = false;
266 HasCalledPrepareModel mHasCalledPrepareModel = HasCalledPrepareModel::NO;
267 };
268
CreateBroadcastAddModel(test_wrapper::Model * model)269 void CreateBroadcastAddModel(test_wrapper::Model* model) {
270 test_wrapper::OperandType matrixType(Type::TENSOR_FLOAT32, {2, 2});
271 test_wrapper::OperandType vectorType(Type::TENSOR_FLOAT32, {2});
272 test_wrapper::OperandType scalarType(Type::INT32, {});
273 int32_t activation(ANEURALNETWORKS_FUSED_NONE);
274 auto a = model->addOperand(&matrixType);
275 auto b = model->addOperand(&vectorType);
276 auto c = model->addOperand(&matrixType);
277 auto d = model->addOperand(&scalarType);
278 model->setOperandValue(d, &activation, sizeof(activation));
279 model->addOperation(ANEURALNETWORKS_ADD, {a, b, d}, {c});
280 model->identifyInputsAndOutputs({a, b}, {c});
281 ASSERT_TRUE(model->isValid());
282 ASSERT_EQ(model->finish(), Result::NO_ERROR);
283 }
284
getDeviceWithName(std::string_view deviceName,const ANeuralNetworksDevice ** outputDevice)285 void getDeviceWithName(std::string_view deviceName, const ANeuralNetworksDevice** outputDevice) {
286 uint32_t numDevices = 0;
287 ASSERT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
288 EXPECT_GE(numDevices, (uint32_t)1);
289
290 int numMatchingDevices = 0;
291 for (uint32_t i = 0; i < numDevices; i++) {
292 ANeuralNetworksDevice* device = nullptr;
293 ASSERT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);
294
295 const char* buffer = nullptr;
296 ASSERT_EQ(ANeuralNetworksDevice_getName(device, &buffer), ANEURALNETWORKS_NO_ERROR);
297 if (deviceName == buffer) {
298 *outputDevice = device;
299 numMatchingDevices++;
300 }
301 }
302
303 EXPECT_LE(numMatchingDevices, 1);
304 }
305
// Test device registration with a driver parameterized with
// - ErrorStatus returning from getNumberOfCacheFilesNeeded
// - Number of model cache files returning from getNumberOfCacheFilesNeeded
// - Number of data cache files returning from getNumberOfCacheFilesNeeded
using DeviceRegistrationTestParam = std::tuple<ErrorStatus, uint32_t, uint32_t>;

// Fixture that constructs a CachingDriver from the test parameters. The prepareModelFromCache
// status is fixed to NONE because these tests never compile a model; they only observe
// whether the driver becomes visible as a device.
class DeviceRegistrationTest : public ::testing::TestWithParam<DeviceRegistrationTestParam> {
   protected:
    static constexpr std::string_view kDeviceName = "deviceTestCompilationCaching";
    // Values the driver reports from getNumberOfCacheFilesNeeded.
    const ErrorStatus kErrorStatusGetNumCacheFiles = std::get<0>(GetParam());
    const uint32_t kNumModelCache = std::get<1>(GetParam());
    const uint32_t kNumDataCache = std::get<2>(GetParam());
    const sp<CachingDriver> kDriver =
            new CachingDriver(kDeviceName, kErrorStatusGetNumCacheFiles, kNumModelCache,
                              kNumDataCache, ErrorStatus::NONE);
};
322
TEST_P(DeviceRegistrationTest, CachingFailure) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }

    // Register the test driver; the scope guard restores the device list on exit.
    DeviceManager::get()->forTest_registerDevice(kDeviceName.data(), kDriver);
    const auto cleanup = android::base::make_scope_guard(
            [] { DeviceManager::get()->forTest_reInitializeDeviceList(); });

    // Look the driver up through the NDK device enumeration.
    const ANeuralNetworksDevice* foundDevice = nullptr;
    getDeviceWithName(kDeviceName, &foundDevice);

    // The device must be visible exactly when its reported caching parameters can pass
    // device initialization.
    const bool expectRegistered =
            canDeviceBeRegistered(kErrorStatusGetNumCacheFiles, kNumModelCache, kNumDataCache);
    ASSERT_EQ(foundDevice != nullptr, expectRegistered);
}
342
// Test model compilation with a driver parameterized with
// - Number of model cache files returning from getNumberOfCacheFilesNeeded
// - Number of data cache files returning from getNumberOfCacheFilesNeeded
// - ErrorStatus returning from prepareModelFromCache_1_3
// The instantiations below use only valid cache file counts, so the driver always registers.
using CompilationCachingTestParam = std::tuple<uint32_t, uint32_t, ErrorStatus>;
348
349 class CompilationCachingTest : public ::testing::TestWithParam<CompilationCachingTestParam> {
350 protected:
SetUp()351 virtual void SetUp() override {
352 char cacheDirTemp[] = "/data/local/tmp/TestCompilationCachingXXXXXX";
353 char* cacheDir = mkdtemp(cacheDirTemp);
354 ASSERT_NE(cacheDir, nullptr);
355 mCacheDir = cacheDir;
356 CreateBroadcastAddModel(&mModel);
357 }
358
TearDown()359 virtual void TearDown() override {
360 if (!::testing::Test::HasFailure()) {
361 std::filesystem::remove_all(mCacheDir);
362 }
363 }
364
compileModel(const sp<CachingDriver> & driver,bool withToken)365 void compileModel(const sp<CachingDriver>& driver, bool withToken) {
366 DeviceManager::get()->forTest_registerDevice(kDeviceName.data(), driver);
367 const auto cleanup = android::base::make_scope_guard(
368 [] { DeviceManager::get()->forTest_reInitializeDeviceList(); });
369
370 // Get a handle to the single driver device matching kDeviceName.
371 const ANeuralNetworksDevice* device = nullptr;
372 getDeviceWithName(kDeviceName, &device);
373 ASSERT_NE(device, nullptr);
374
375 // Compile the model with the device.
376 ANeuralNetworksCompilation* compilation = nullptr;
377 ASSERT_EQ(ANeuralNetworksCompilation_createForDevices(mModel.getHandle(), &device, 1,
378 &compilation),
379 ANEURALNETWORKS_NO_ERROR);
380 if (withToken) {
381 ASSERT_EQ(ANeuralNetworksCompilation_setCaching(compilation, mCacheDir.c_str(),
382 kToken.data()),
383 ANEURALNETWORKS_NO_ERROR);
384 }
385 ASSERT_EQ(ANeuralNetworksCompilation_finish(compilation), ANEURALNETWORKS_NO_ERROR);
386 }
387
createCache()388 void createCache() {
389 sp<CachingDriver> driver = new CachingDriver(kDeviceName, ErrorStatus::NONE, kNumModelCache,
390 kNumDataCache, ErrorStatus::NONE);
391 compileModel(driver, /*withToken=*/true);
392 }
393
394 static constexpr std::string_view kDeviceName = "deviceTestCompilationCaching";
395 const uint32_t kNumModelCache = std::get<0>(GetParam());
396 const uint32_t kNumDataCache = std::get<1>(GetParam());
397 const ErrorStatus kErrorStatusPrepareFromCache = std::get<2>(GetParam());
398 const bool kIsCachingSupported = isCachingSupported(kNumModelCache, kNumDataCache);
399 test_wrapper::Model mModel;
400 std::string mCacheDir;
401 const CacheToken kToken{};
402 };
403
TEST_P(CompilationCachingTest, TokenProvidedAndCacheNotExist) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }
    sp<CachingDriver> testDriver = new CachingDriver(kDeviceName, ErrorStatus::NONE, kNumModelCache,
                                                     kNumDataCache, kErrorStatusPrepareFromCache);
    compileModel(testDriver, /*withToken=*/true);

    // When cache file does not exist, the runtime should never call prepareModelFromCache_1_3.
    EXPECT_FALSE(testDriver->hasCalledPrepareModelFromCache());

    // The runtime should call prepareModel_1_3. It should request caching iff caching supported.
    HasCalledPrepareModel expected = HasCalledPrepareModel::WITHOUT_CACHING;
    if (kIsCachingSupported) {
        expected = HasCalledPrepareModel::WITH_CACHING;
    }
    EXPECT_EQ(testDriver->hasCalledPrepareModel(), expected);
}
420
TEST_P(CompilationCachingTest, TokenProvidedAndCacheExist) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }
    createCache();
    sp<CachingDriver> testDriver = new CachingDriver(kDeviceName, ErrorStatus::NONE, kNumModelCache,
                                                     kNumDataCache, kErrorStatusPrepareFromCache);
    compileModel(testDriver, /*withToken=*/true);

    // When cache files exist, the runtime should call prepareModelFromCache_1_3 iff caching
    // supported.
    EXPECT_EQ(testDriver->hasCalledPrepareModelFromCache(), kIsCachingSupported);

    HasCalledPrepareModel expected;
    if (!kIsCachingSupported) {
        // Caching not supported: the runtime compiles normally without requesting caching.
        expected = HasCalledPrepareModel::WITHOUT_CACHING;
    } else if (kErrorStatusPrepareFromCache == ErrorStatus::NONE) {
        // Preparation from cache succeeded: no prepareModel_1_3 call at all.
        expected = HasCalledPrepareModel::NO;
    } else {
        // Preparation from cache failed: the runtime recompiles and requests caching again.
        expected = HasCalledPrepareModel::WITH_CACHING;
    }
    EXPECT_EQ(testDriver->hasCalledPrepareModel(), expected);
}
451
TEST_P(CompilationCachingTest, TokenNotProvided) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }
    sp<CachingDriver> testDriver = new CachingDriver(kDeviceName, ErrorStatus::NONE, kNumModelCache,
                                                     kNumDataCache, kErrorStatusPrepareFromCache);
    compileModel(testDriver, /*withToken=*/false);

    // With no NDK caching token set by the client, the runtime must neither call
    // prepareModelFromCache_1_3 nor pass cache handles to prepareModel_1_3.
    EXPECT_FALSE(testDriver->hasCalledPrepareModelFromCache());
    EXPECT_EQ(testDriver->hasCalledPrepareModel(), HasCalledPrepareModel::WITHOUT_CACHING);
}
465
// Statuses getNumberOfCacheFilesNeeded may report: success and one representative failure.
static const auto kErrorStatusGetNumCacheFilesChoices =
        testing::Values(ErrorStatus::NONE, ErrorStatus::DEVICE_UNAVAILABLE);
// Cache file counts including one past the limit, to exercise registration rejection.
static const auto kNumCacheChoices =
        testing::Values(0ul, 1ul, static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES),
                        static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES) + 1);
// Cache file counts that are all within the allowed maximum.
static const auto kNumValidCacheChoices =
        testing::Values(0ul, 1ul, static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES));
// Statuses prepareModelFromCache_1_3 may report: success plus several failure modes.
static const auto kErrorStatusPrepareFromCacheChoices =
        testing::Values(ErrorStatus::NONE, ErrorStatus::GENERAL_FAILURE,
                        ErrorStatus::DEVICE_UNAVAILABLE, ErrorStatus::INVALID_ARGUMENT);

// Registration tests run over invalid counts too; compilation tests need a registered
// device, so they use only the valid count choices.
INSTANTIATE_TEST_CASE_P(TestCompilationCaching, DeviceRegistrationTest,
                        testing::Combine(kErrorStatusGetNumCacheFilesChoices, kNumCacheChoices,
                                         kNumCacheChoices));

INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingTest,
                        testing::Combine(kNumValidCacheChoices, kNumValidCacheChoices,
                                         kErrorStatusPrepareFromCacheChoices));
484
485 } // namespace
486