1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <gtest/gtest.h>
18
19 #include <memory>
20 #include <vector>
21
22 #include "CompilationBuilder.h"
23 #include "ExecutionPlan.h"
24 #include "Manager.h"
25 #include "SampleDriverPartial.h"
26 #include "TestNeuralNetworksWrapper.h"
27
28 namespace android::nn {
29 namespace {
30
31 using namespace hal;
32 using sample_driver::SampleDriverPartial;
33 using Result = test_wrapper::Result;
34 using WrapperOperandType = test_wrapper::OperandType;
35 using WrapperCompilation = test_wrapper::Compilation;
36 using WrapperExecution = test_wrapper::Execution;
37 using WrapperType = test_wrapper::Type;
38 using WrapperModel = test_wrapper::Model;
39
40 class EmptyOperationResolver : public IOperationResolver {
41 public:
findOperation(OperationType) const42 const OperationRegistration* findOperation(OperationType) const override { return nullptr; }
43 };
44
// Name given to FailingTestDriver and to the test Device wrapping it.
// A constexpr array (instead of a re-pointable `const char*`) keeps the
// constant immutable and free of runtime initialization.
constexpr char kTestDriverName[] = "nnapi-test-sqrt-failing";
46
47 // A driver that only supports SQRT and fails during execution.
48 class FailingTestDriver : public SampleDriverPartial {
49 public:
50 // EmptyOperationResolver causes execution to fail.
FailingTestDriver()51 FailingTestDriver() : SampleDriverPartial(kTestDriverName, &mEmptyOperationResolver) {}
52
getCapabilities_1_3(getCapabilities_1_3_cb cb)53 Return<void> getCapabilities_1_3(getCapabilities_1_3_cb cb) override {
54 cb(V1_3::ErrorStatus::NONE,
55 {.operandPerformance = {{.type = OperandType::TENSOR_FLOAT32,
56 .info = {.execTime = 0.1, // Faster than CPU.
57 .powerUsage = 0.1}}}});
58 return Void();
59 }
60
61 private:
getSupportedOperationsImpl(const Model & model) const62 std::vector<bool> getSupportedOperationsImpl(const Model& model) const override {
63 std::vector<bool> supported(model.main.operations.size());
64 std::transform(
65 model.main.operations.begin(), model.main.operations.end(), supported.begin(),
66 [](const Operation& operation) { return operation.type == OperationType::SQRT; });
67 return supported;
68 }
69
70 const EmptyOperationResolver mEmptyOperationResolver;
71 };
72
73 class FailingDriverTest : public ::testing::Test {
SetUp()74 virtual void SetUp() {
75 DeviceManager* deviceManager = DeviceManager::get();
76 if (deviceManager->getUseCpuOnly() ||
77 !DeviceManager::partitioningAllowsFallback(deviceManager->getPartitioning())) {
78 GTEST_SKIP();
79 }
80 mTestDevice =
81 DeviceManager::forTest_makeDriverDevice(kTestDriverName, new FailingTestDriver());
82 deviceManager->forTest_setDevices({
83 mTestDevice,
84 DeviceManager::getCpuDevice(),
85 });
86 }
87
TearDown()88 virtual void TearDown() { DeviceManager::get()->forTest_reInitializeDeviceList(); }
89
90 protected:
91 std::shared_ptr<Device> mTestDevice;
92 };
93
94 // Regression test for b/152623150.
TEST_F(FailingDriverTest,FailAfterInterpretedWhile)95 TEST_F(FailingDriverTest, FailAfterInterpretedWhile) {
96 // Model:
97 // f = input0
98 // b = input1
99 // while CAST(b): # Identity cast.
100 // f = CAST(f)
101 // # FailingTestDriver fails here. When partial CPU fallback happens,
102 // # it should not loop forever.
103 // output0 = SQRT(f)
104
105 WrapperOperandType floatType(WrapperType::TENSOR_FLOAT32, {2});
106 WrapperOperandType boolType(WrapperType::TENSOR_BOOL8, {1});
107
108 WrapperModel conditionModel;
109 {
110 uint32_t f = conditionModel.addOperand(&floatType);
111 uint32_t b = conditionModel.addOperand(&boolType);
112 uint32_t out = conditionModel.addOperand(&boolType);
113 conditionModel.addOperation(ANEURALNETWORKS_CAST, {b}, {out});
114 conditionModel.identifyInputsAndOutputs({f, b}, {out});
115 ASSERT_EQ(conditionModel.finish(), Result::NO_ERROR);
116 ASSERT_TRUE(conditionModel.isValid());
117 }
118
119 WrapperModel bodyModel;
120 {
121 uint32_t f = bodyModel.addOperand(&floatType);
122 uint32_t b = bodyModel.addOperand(&boolType);
123 uint32_t out = bodyModel.addOperand(&floatType);
124 bodyModel.addOperation(ANEURALNETWORKS_CAST, {f}, {out});
125 bodyModel.identifyInputsAndOutputs({f, b}, {out});
126 ASSERT_EQ(bodyModel.finish(), Result::NO_ERROR);
127 ASSERT_TRUE(bodyModel.isValid());
128 }
129
130 WrapperModel model;
131 {
132 uint32_t fInput = model.addOperand(&floatType);
133 uint32_t bInput = model.addOperand(&boolType);
134 uint32_t fTmp = model.addOperand(&floatType);
135 uint32_t fSqrt = model.addOperand(&floatType);
136 uint32_t cond = model.addModelOperand(&conditionModel);
137 uint32_t body = model.addModelOperand(&bodyModel);
138 model.addOperation(ANEURALNETWORKS_WHILE, {cond, body, fInput, bInput}, {fTmp});
139 model.addOperation(ANEURALNETWORKS_SQRT, {fTmp}, {fSqrt});
140 model.identifyInputsAndOutputs({fInput, bInput}, {fSqrt});
141 ASSERT_TRUE(model.isValid());
142 ASSERT_EQ(model.finish(), Result::NO_ERROR);
143 }
144
145 WrapperCompilation compilation(&model);
146 ASSERT_EQ(compilation.finish(), Result::NO_ERROR);
147
148 const CompilationBuilder* compilationBuilder =
149 reinterpret_cast<CompilationBuilder*>(compilation.getHandle());
150 const ExecutionPlan& plan = compilationBuilder->forTest_getExecutionPlan();
151 const std::vector<std::shared_ptr<LogicalStep>>& steps = plan.forTest_compoundGetSteps();
152 ASSERT_EQ(steps.size(), 6u);
153 ASSERT_TRUE(steps[0]->isWhile());
154 ASSERT_TRUE(steps[1]->isExecution());
155 ASSERT_EQ(steps[1]->executionStep()->getDevice(), DeviceManager::getCpuDevice());
156 ASSERT_TRUE(steps[2]->isGoto());
157 ASSERT_TRUE(steps[3]->isExecution());
158 ASSERT_EQ(steps[3]->executionStep()->getDevice(), DeviceManager::getCpuDevice());
159 ASSERT_TRUE(steps[4]->isGoto());
160 ASSERT_TRUE(steps[5]->isExecution());
161 ASSERT_EQ(steps[5]->executionStep()->getDevice(), mTestDevice);
162
163 WrapperExecution execution(&compilation);
164 const float fInput[] = {12 * 12, 5 * 5};
165 const bool8 bInput = false;
166 float fSqrt[] = {0, 0};
167 ASSERT_EQ(execution.setInput(0, &fInput), Result::NO_ERROR);
168 ASSERT_EQ(execution.setInput(1, &bInput), Result::NO_ERROR);
169 ASSERT_EQ(execution.setOutput(0, &fSqrt), Result::NO_ERROR);
170 ASSERT_EQ(execution.compute(), Result::NO_ERROR);
171 ASSERT_EQ(fSqrt[0], 12);
172 ASSERT_EQ(fSqrt[1], 5);
173 }
174
175 // Regression test for b/155923033.
TEST_F(FailingDriverTest,SimplePlan)176 TEST_F(FailingDriverTest, SimplePlan) {
177 // Model:
178 // output0 = SQRT(input0)
179 //
180 // This results in a SIMPLE execution plan. When FailingTestDriver fails,
181 // partial CPU fallback should complete the execution.
182
183 WrapperOperandType floatType(WrapperType::TENSOR_FLOAT32, {2});
184
185 WrapperModel model;
186 {
187 uint32_t fInput = model.addOperand(&floatType);
188 uint32_t fSqrt = model.addOperand(&floatType);
189 model.addOperation(ANEURALNETWORKS_SQRT, {fInput}, {fSqrt});
190 model.identifyInputsAndOutputs({fInput}, {fSqrt});
191 ASSERT_TRUE(model.isValid());
192 ASSERT_EQ(model.finish(), Result::NO_ERROR);
193 }
194
195 WrapperCompilation compilation(&model);
196 ASSERT_EQ(compilation.finish(), Result::NO_ERROR);
197
198 const CompilationBuilder* compilationBuilder =
199 reinterpret_cast<CompilationBuilder*>(compilation.getHandle());
200 const ExecutionPlan& plan = compilationBuilder->forTest_getExecutionPlan();
201 ASSERT_TRUE(plan.isSimple());
202
203 WrapperExecution execution(&compilation);
204 const float fInput[] = {12 * 12, 5 * 5};
205 float fSqrt[] = {0, 0};
206 ASSERT_EQ(execution.setInput(0, &fInput), Result::NO_ERROR);
207 ASSERT_EQ(execution.setOutput(0, &fSqrt), Result::NO_ERROR);
208 ASSERT_EQ(execution.compute(), Result::NO_ERROR);
209 ASSERT_EQ(fSqrt[0], 12);
210 ASSERT_EQ(fSqrt[1], 5);
211 }
212
213 } // namespace
214 } // namespace android::nn
215