1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <gtest/gtest.h>
18 
19 #include <memory>
20 #include <vector>
21 
22 #include "CompilationBuilder.h"
23 #include "ExecutionPlan.h"
24 #include "Manager.h"
25 #include "SampleDriverPartial.h"
26 #include "TestNeuralNetworksWrapper.h"
27 
28 namespace android::nn {
29 namespace {
30 
31 using namespace hal;
32 using sample_driver::SampleDriverPartial;
33 using Result = test_wrapper::Result;
34 using WrapperOperandType = test_wrapper::OperandType;
35 using WrapperCompilation = test_wrapper::Compilation;
36 using WrapperExecution = test_wrapper::Execution;
37 using WrapperType = test_wrapper::Type;
38 using WrapperModel = test_wrapper::Model;
39 
40 class EmptyOperationResolver : public IOperationResolver {
41    public:
findOperation(OperationType) const42     const OperationRegistration* findOperation(OperationType) const override { return nullptr; }
43 };
44 
// Name under which FailingTestDriver is created and registered with the
// DeviceManager. It is a constexpr array rather than a reassignable
// `const char*`, so neither the characters nor the binding can change.
constexpr char kTestDriverName[] = "nnapi-test-sqrt-failing";
46 
47 // A driver that only supports SQRT and fails during execution.
48 class FailingTestDriver : public SampleDriverPartial {
49    public:
50     // EmptyOperationResolver causes execution to fail.
FailingTestDriver()51     FailingTestDriver() : SampleDriverPartial(kTestDriverName, &mEmptyOperationResolver) {}
52 
getCapabilities_1_3(getCapabilities_1_3_cb cb)53     Return<void> getCapabilities_1_3(getCapabilities_1_3_cb cb) override {
54         cb(V1_3::ErrorStatus::NONE,
55            {.operandPerformance = {{.type = OperandType::TENSOR_FLOAT32,
56                                     .info = {.execTime = 0.1,  // Faster than CPU.
57                                              .powerUsage = 0.1}}}});
58         return Void();
59     }
60 
61    private:
getSupportedOperationsImpl(const Model & model) const62     std::vector<bool> getSupportedOperationsImpl(const Model& model) const override {
63         std::vector<bool> supported(model.main.operations.size());
64         std::transform(
65                 model.main.operations.begin(), model.main.operations.end(), supported.begin(),
66                 [](const Operation& operation) { return operation.type == OperationType::SQRT; });
67         return supported;
68     }
69 
70     const EmptyOperationResolver mEmptyOperationResolver;
71 };
72 
73 class FailingDriverTest : public ::testing::Test {
SetUp()74     virtual void SetUp() {
75         DeviceManager* deviceManager = DeviceManager::get();
76         if (deviceManager->getUseCpuOnly() ||
77             !DeviceManager::partitioningAllowsFallback(deviceManager->getPartitioning())) {
78             GTEST_SKIP();
79         }
80         mTestDevice =
81                 DeviceManager::forTest_makeDriverDevice(kTestDriverName, new FailingTestDriver());
82         deviceManager->forTest_setDevices({
83                 mTestDevice,
84                 DeviceManager::getCpuDevice(),
85         });
86     }
87 
TearDown()88     virtual void TearDown() { DeviceManager::get()->forTest_reInitializeDeviceList(); }
89 
90    protected:
91     std::shared_ptr<Device> mTestDevice;
92 };
93 
94 // Regression test for b/152623150.
TEST_F(FailingDriverTest, FailAfterInterpretedWhile) {
    // Model:
    //     f = input0
    //     b = input1
    //     while CAST(b):  # Identity cast.
    //         f = CAST(f)
    //     # FailingTestDriver fails here. When partial CPU fallback happens,
    //     # it should not loop forever.
    //     output0 = SQRT(f)

    WrapperOperandType floatType(WrapperType::TENSOR_FLOAT32, {2});
    WrapperOperandType boolType(WrapperType::TENSOR_BOOL8, {1});

    // Condition subgraph: inputs (f, b), output CAST(b).
    // Operand indices follow creation order, so keep these statements in order.
    WrapperModel conditionModel;
    {
        uint32_t f = conditionModel.addOperand(&floatType);
        uint32_t b = conditionModel.addOperand(&boolType);
        uint32_t out = conditionModel.addOperand(&boolType);
        conditionModel.addOperation(ANEURALNETWORKS_CAST, {b}, {out});
        conditionModel.identifyInputsAndOutputs({f, b}, {out});
        ASSERT_EQ(conditionModel.finish(), Result::NO_ERROR);
        ASSERT_TRUE(conditionModel.isValid());
    }

    // Body subgraph: inputs (f, b), output CAST(f).
    WrapperModel bodyModel;
    {
        uint32_t f = bodyModel.addOperand(&floatType);
        uint32_t b = bodyModel.addOperand(&boolType);
        uint32_t out = bodyModel.addOperand(&floatType);
        bodyModel.addOperation(ANEURALNETWORKS_CAST, {f}, {out});
        bodyModel.identifyInputsAndOutputs({f, b}, {out});
        ASSERT_EQ(bodyModel.finish(), Result::NO_ERROR);
        ASSERT_TRUE(bodyModel.isValid());
    }

    // Main model:
    //     WHILE(cond, body, fInput, bInput) -> fTmp
    //     SQRT(fTmp) -> fSqrt
    // fSqrt is the only model output.
    WrapperModel model;
    {
        uint32_t fInput = model.addOperand(&floatType);
        uint32_t bInput = model.addOperand(&boolType);
        uint32_t fTmp = model.addOperand(&floatType);
        uint32_t fSqrt = model.addOperand(&floatType);
        uint32_t cond = model.addModelOperand(&conditionModel);
        uint32_t body = model.addModelOperand(&bodyModel);
        model.addOperation(ANEURALNETWORKS_WHILE, {cond, body, fInput, bInput}, {fTmp});
        model.addOperation(ANEURALNETWORKS_SQRT, {fTmp}, {fSqrt});
        model.identifyInputsAndOutputs({fInput, bInput}, {fSqrt});
        ASSERT_TRUE(model.isValid());
        ASSERT_EQ(model.finish(), Result::NO_ERROR);
    }

    WrapperCompilation compilation(&model);
    ASSERT_EQ(compilation.finish(), Result::NO_ERROR);

    // Inspect the partitioned plan. Six compound steps are expected:
    //   [0] the WHILE step,
    //   [1] a CPU execution step, followed by a GOTO at [2],
    //   [3] a CPU execution step, followed by a GOTO at [4],
    //   [5] the SQRT execution step on FailingTestDriver.
    // Steps [1] and [3] are presumably the condition and body subgraphs; the
    // asserts below only check step kinds and devices.
    const CompilationBuilder* compilationBuilder =
            reinterpret_cast<CompilationBuilder*>(compilation.getHandle());
    const ExecutionPlan& plan = compilationBuilder->forTest_getExecutionPlan();
    const std::vector<std::shared_ptr<LogicalStep>>& steps = plan.forTest_compoundGetSteps();
    ASSERT_EQ(steps.size(), 6u);
    ASSERT_TRUE(steps[0]->isWhile());
    ASSERT_TRUE(steps[1]->isExecution());
    ASSERT_EQ(steps[1]->executionStep()->getDevice(), DeviceManager::getCpuDevice());
    ASSERT_TRUE(steps[2]->isGoto());
    ASSERT_TRUE(steps[3]->isExecution());
    ASSERT_EQ(steps[3]->executionStep()->getDevice(), DeviceManager::getCpuDevice());
    ASSERT_TRUE(steps[4]->isGoto());
    ASSERT_TRUE(steps[5]->isExecution());
    ASSERT_EQ(steps[5]->executionStep()->getDevice(), mTestDevice);

    // bInput is false, so the loop condition CAST(b) is false and the body
    // should not run; the output is then SQRT(fInput). The SQRT step on
    // FailingTestDriver fails. Partial CPU fallback must still complete the
    // execution, without looping forever, and produce the exact square roots
    // of the perfect-square inputs.
    WrapperExecution execution(&compilation);
    const float fInput[] = {12 * 12, 5 * 5};
    const bool8 bInput = false;
    float fSqrt[] = {0, 0};
    ASSERT_EQ(execution.setInput(0, &fInput), Result::NO_ERROR);
    ASSERT_EQ(execution.setInput(1, &bInput), Result::NO_ERROR);
    ASSERT_EQ(execution.setOutput(0, &fSqrt), Result::NO_ERROR);
    ASSERT_EQ(execution.compute(), Result::NO_ERROR);
    ASSERT_EQ(fSqrt[0], 12);
    ASSERT_EQ(fSqrt[1], 5);
}
174 
175 // Regression test for b/155923033.
TEST_F(FailingDriverTest, SimplePlan) {
    // Model:
    //     output0 = SQRT(input0)
    //
    // A lone SQRT compiles to a SIMPLE (single-step) execution plan.
    // FailingTestDriver claims that step and then fails while executing it.
    // Partial CPU fallback is expected to finish the execution anyway.

    WrapperOperandType tensorType(WrapperType::TENSOR_FLOAT32, {2});

    // One input operand feeding one SQRT whose result is the model output.
    WrapperModel model;
    {
        const uint32_t input = model.addOperand(&tensorType);
        const uint32_t root = model.addOperand(&tensorType);
        model.addOperation(ANEURALNETWORKS_SQRT, {input}, {root});
        model.identifyInputsAndOutputs({input}, {root});
        ASSERT_TRUE(model.isValid());
        ASSERT_EQ(model.finish(), Result::NO_ERROR);
    }

    WrapperCompilation compilation(&model);
    ASSERT_EQ(compilation.finish(), Result::NO_ERROR);

    // The partitioner must have produced a SIMPLE plan, not a compound one.
    const auto* builder = reinterpret_cast<const CompilationBuilder*>(compilation.getHandle());
    ASSERT_TRUE(builder->forTest_getExecutionPlan().isSimple());

    // The inputs are perfect squares, so exact results are expected.
    const float inputValues[] = {12 * 12, 5 * 5};
    float outputValues[] = {0, 0};
    WrapperExecution execution(&compilation);
    ASSERT_EQ(execution.setInput(0, &inputValues), Result::NO_ERROR);
    ASSERT_EQ(execution.setOutput(0, &outputValues), Result::NO_ERROR);
    ASSERT_EQ(execution.compute(), Result::NO_ERROR);
    ASSERT_EQ(outputValues[0], 12);
    ASSERT_EQ(outputValues[1], 5);
}
212 
213 }  // namespace
214 }  // namespace android::nn
215