1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
18 
19 #include "VtsHalNeuralnetworks.h"
20 
21 #include "1.2/Callbacks.h"
22 #include "ExecutionBurstController.h"
23 #include "ExecutionBurstServer.h"
24 #include "GeneratedTestHarness.h"
25 #include "TestHarness.h"
26 
27 #include <android-base/logging.h>
28 #include <chrono>
29 #include <cstring>
30 
31 namespace android::hardware::neuralnetworks::V1_3::vts::functional {
32 
33 using nn::ExecutionBurstController;
34 using nn::RequestChannelSender;
35 using nn::ResultChannelReceiver;
36 using V1_0::Request;
37 using V1_2::FmqRequestDatum;
38 using V1_2::FmqResultDatum;
39 using V1_2::IBurstCallback;
40 using V1_2::IBurstContext;
41 using V1_2::MeasureTiming;
42 using V1_2::Timing;
43 using ExecutionBurstCallback = ExecutionBurstController::ExecutionBurstCallback;
44 
45 using BurstExecutionMutation = std::function<void(std::vector<FmqRequestDatum>*)>;
46 
// Length (in FmqRequestDatum / FmqResultDatum entries) used for the request
// FMQ of every burst in this file, and the default length of the result FMQ.
// It is large enough to return a result from a burst execution for all of the
// generated test cases.
constexpr size_t kExecutionBurstChannelLength{1024};

// Deliberately undersized result FMQ length: it is too short to return the
// result of a burst execution for some of the generated test cases, which is
// what validateBurstFmqLength relies on.
constexpr size_t kExecutionBurstChannelSmallLength{8};
55 
56 ///////////////////////// UTILITY FUNCTIONS /////////////////////////
57 
badTiming(Timing timing)58 static bool badTiming(Timing timing) {
59     return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX;
60 }
61 
createBurst(const sp<IPreparedModel> & preparedModel,const sp<IBurstCallback> & callback,std::unique_ptr<RequestChannelSender> * sender,std::unique_ptr<ResultChannelReceiver> * receiver,sp<IBurstContext> * context,size_t resultChannelLength=kExecutionBurstChannelLength)62 static void createBurst(const sp<IPreparedModel>& preparedModel, const sp<IBurstCallback>& callback,
63                         std::unique_ptr<RequestChannelSender>* sender,
64                         std::unique_ptr<ResultChannelReceiver>* receiver,
65                         sp<IBurstContext>* context,
66                         size_t resultChannelLength = kExecutionBurstChannelLength) {
67     ASSERT_NE(nullptr, preparedModel.get());
68     ASSERT_NE(nullptr, sender);
69     ASSERT_NE(nullptr, receiver);
70     ASSERT_NE(nullptr, context);
71 
72     // create FMQ objects
73     auto [fmqRequestChannel, fmqRequestDescriptor] =
74             RequestChannelSender::create(kExecutionBurstChannelLength);
75     auto [fmqResultChannel, fmqResultDescriptor] =
76             ResultChannelReceiver::create(resultChannelLength, std::chrono::microseconds{0});
77     ASSERT_NE(nullptr, fmqRequestChannel.get());
78     ASSERT_NE(nullptr, fmqResultChannel.get());
79     ASSERT_NE(nullptr, fmqRequestDescriptor);
80     ASSERT_NE(nullptr, fmqResultDescriptor);
81 
82     // configure burst
83     V1_0::ErrorStatus errorStatus;
84     sp<IBurstContext> burstContext;
85     const Return<void> ret = preparedModel->configureExecutionBurst(
86             callback, *fmqRequestDescriptor, *fmqResultDescriptor,
87             [&errorStatus, &burstContext](V1_0::ErrorStatus status,
88                                           const sp<IBurstContext>& context) {
89                 errorStatus = status;
90                 burstContext = context;
91             });
92     ASSERT_TRUE(ret.isOk());
93     ASSERT_EQ(V1_0::ErrorStatus::NONE, errorStatus);
94     ASSERT_NE(nullptr, burstContext.get());
95 
96     // return values
97     *sender = std::move(fmqRequestChannel);
98     *receiver = std::move(fmqResultChannel);
99     *context = burstContext;
100 }
101 
createBurstWithResultChannelLength(const sp<IPreparedModel> & preparedModel,size_t resultChannelLength,std::shared_ptr<ExecutionBurstController> * controller)102 static void createBurstWithResultChannelLength(
103         const sp<IPreparedModel>& preparedModel, size_t resultChannelLength,
104         std::shared_ptr<ExecutionBurstController>* controller) {
105     ASSERT_NE(nullptr, preparedModel.get());
106     ASSERT_NE(nullptr, controller);
107 
108     // create FMQ objects
109     std::unique_ptr<RequestChannelSender> sender;
110     std::unique_ptr<ResultChannelReceiver> receiver;
111     sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
112     sp<IBurstContext> context;
113     ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context,
114                                         resultChannelLength));
115     ASSERT_NE(nullptr, sender.get());
116     ASSERT_NE(nullptr, receiver.get());
117     ASSERT_NE(nullptr, context.get());
118 
119     // return values
120     *controller = std::make_shared<ExecutionBurstController>(std::move(sender), std::move(receiver),
121                                                              context, callback);
122 }
123 
124 // Primary validation function. This function will take a valid serialized
125 // request, apply a mutation to it to invalidate the serialized request, then
126 // pass it to interface calls that use the serialized request.
validate(RequestChannelSender * sender,ResultChannelReceiver * receiver,const std::string & message,const std::vector<FmqRequestDatum> & originalSerialized,const BurstExecutionMutation & mutate)127 static void validate(RequestChannelSender* sender, ResultChannelReceiver* receiver,
128                      const std::string& message,
129                      const std::vector<FmqRequestDatum>& originalSerialized,
130                      const BurstExecutionMutation& mutate) {
131     std::vector<FmqRequestDatum> serialized = originalSerialized;
132     mutate(&serialized);
133 
134     // skip if packet is too large to send
135     if (serialized.size() > kExecutionBurstChannelLength) {
136         return;
137     }
138 
139     SCOPED_TRACE(message);
140 
141     // send invalid packet
142     ASSERT_TRUE(sender->sendPacket(serialized));
143 
144     // receive error
145     auto results = receiver->getBlocking();
146     ASSERT_TRUE(results.has_value());
147     const auto [status, outputShapes, timing] = std::move(*results);
148     EXPECT_NE(V1_0::ErrorStatus::NONE, status);
149     EXPECT_EQ(0u, outputShapes.size());
150     EXPECT_TRUE(badTiming(timing));
151 }
152 
153 // For validation, valid packet entries are mutated to invalid packet entries,
154 // or invalid packet entries are inserted into valid packets. This function
155 // creates pre-set invalid packet entries for convenience.
createBadRequestPacketEntries()156 static std::vector<FmqRequestDatum> createBadRequestPacketEntries() {
157     const FmqRequestDatum::PacketInformation packetInformation = {
158             /*.packetSize=*/10, /*.numberOfInputOperands=*/10, /*.numberOfOutputOperands=*/10,
159             /*.numberOfPools=*/10};
160     const FmqRequestDatum::OperandInformation operandInformation = {
161             /*.hasNoValue=*/false, /*.location=*/{}, /*.numberOfDimensions=*/10};
162     const int32_t invalidPoolIdentifier = std::numeric_limits<int32_t>::max();
163     std::vector<FmqRequestDatum> bad(7);
164     bad[0].packetInformation(packetInformation);
165     bad[1].inputOperandInformation(operandInformation);
166     bad[2].inputOperandDimensionValue(0);
167     bad[3].outputOperandInformation(operandInformation);
168     bad[4].outputOperandDimensionValue(0);
169     bad[5].poolIdentifier(invalidPoolIdentifier);
170     bad[6].measureTiming(MeasureTiming::YES);
171     return bad;
172 }
173 
174 // For validation, valid packet entries are mutated to invalid packet entries,
175 // or invalid packet entries are inserted into valid packets. This function
176 // retrieves pre-set invalid packet entries for convenience. This function
177 // caches these data so they can be reused on subsequent validation checks.
getBadRequestPacketEntries()178 static const std::vector<FmqRequestDatum>& getBadRequestPacketEntries() {
179     static const std::vector<FmqRequestDatum> bad = createBadRequestPacketEntries();
180     return bad;
181 }
182 
183 ///////////////////////// REMOVE DATUM ////////////////////////////////////
184 
removeDatumTest(RequestChannelSender * sender,ResultChannelReceiver * receiver,const std::vector<FmqRequestDatum> & serialized)185 static void removeDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
186                             const std::vector<FmqRequestDatum>& serialized) {
187     for (size_t index = 0; index < serialized.size(); ++index) {
188         const std::string message = "removeDatum: removed datum at index " + std::to_string(index);
189         validate(sender, receiver, message, serialized,
190                  [index](std::vector<FmqRequestDatum>* serialized) {
191                      serialized->erase(serialized->begin() + index);
192                  });
193     }
194 }
195 
196 ///////////////////////// ADD DATUM ////////////////////////////////////
197 
addDatumTest(RequestChannelSender * sender,ResultChannelReceiver * receiver,const std::vector<FmqRequestDatum> & serialized)198 static void addDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
199                          const std::vector<FmqRequestDatum>& serialized) {
200     const std::vector<FmqRequestDatum>& extra = getBadRequestPacketEntries();
201     for (size_t index = 0; index <= serialized.size(); ++index) {
202         for (size_t type = 0; type < extra.size(); ++type) {
203             const std::string message = "addDatum: added datum type " + std::to_string(type) +
204                                         " at index " + std::to_string(index);
205             validate(sender, receiver, message, serialized,
206                      [index, type, &extra](std::vector<FmqRequestDatum>* serialized) {
207                          serialized->insert(serialized->begin() + index, extra[type]);
208                      });
209         }
210     }
211 }
212 
213 ///////////////////////// MUTATE DATUM ////////////////////////////////////
214 
interestingCase(const FmqRequestDatum & lhs,const FmqRequestDatum & rhs)215 static bool interestingCase(const FmqRequestDatum& lhs, const FmqRequestDatum& rhs) {
216     using Discriminator = FmqRequestDatum::hidl_discriminator;
217 
218     const bool differentValues = (lhs != rhs);
219     const bool sameDiscriminator = (lhs.getDiscriminator() == rhs.getDiscriminator());
220     const auto discriminator = rhs.getDiscriminator();
221     const bool isDimensionValue = (discriminator == Discriminator::inputOperandDimensionValue ||
222                                    discriminator == Discriminator::outputOperandDimensionValue);
223 
224     return differentValues && !(sameDiscriminator && isDimensionValue);
225 }
226 
mutateDatumTest(RequestChannelSender * sender,ResultChannelReceiver * receiver,const std::vector<FmqRequestDatum> & serialized)227 static void mutateDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
228                             const std::vector<FmqRequestDatum>& serialized) {
229     const std::vector<FmqRequestDatum>& change = getBadRequestPacketEntries();
230     for (size_t index = 0; index < serialized.size(); ++index) {
231         for (size_t type = 0; type < change.size(); ++type) {
232             if (interestingCase(serialized[index], change[type])) {
233                 const std::string message = "mutateDatum: changed datum at index " +
234                                             std::to_string(index) + " to datum type " +
235                                             std::to_string(type);
236                 validate(sender, receiver, message, serialized,
237                          [index, type, &change](std::vector<FmqRequestDatum>* serialized) {
238                              (*serialized)[index] = change[type];
239                          });
240             }
241         }
242     }
243 }
244 
///////////////////////// BURST VALIDATION TESTS ////////////////////////////////////
246 
validateBurstSerialization(const sp<IPreparedModel> & preparedModel,const Request & request)247 static void validateBurstSerialization(const sp<IPreparedModel>& preparedModel,
248                                        const Request& request) {
249     // create burst
250     std::unique_ptr<RequestChannelSender> sender;
251     std::unique_ptr<ResultChannelReceiver> receiver;
252     sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
253     sp<IBurstContext> context;
254     ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
255     ASSERT_NE(nullptr, sender.get());
256     ASSERT_NE(nullptr, receiver.get());
257     ASSERT_NE(nullptr, context.get());
258 
259     // load memory into callback slots
260     std::vector<intptr_t> keys;
261     keys.reserve(request.pools.size());
262     std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
263                    [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
264     const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
265 
266     // ensure slot std::numeric_limits<int32_t>::max() doesn't exist (for
267     // subsequent slot validation testing)
268     ASSERT_TRUE(std::all_of(slots.begin(), slots.end(), [](int32_t slot) {
269         return slot != std::numeric_limits<int32_t>::max();
270     }));
271 
272     // serialize the request
273     const auto serialized = android::nn::serialize(request, MeasureTiming::YES, slots);
274 
275     // validations
276     removeDatumTest(sender.get(), receiver.get(), serialized);
277     addDatumTest(sender.get(), receiver.get(), serialized);
278     mutateDatumTest(sender.get(), receiver.get(), serialized);
279 }
280 
281 // This test validates that when the Result message size exceeds length of the
282 // result FMQ, the service instance gracefully fails and returns an error.
validateBurstFmqLength(const sp<IPreparedModel> & preparedModel,const Request & request)283 static void validateBurstFmqLength(const sp<IPreparedModel>& preparedModel,
284                                    const Request& request) {
285     // create regular burst
286     std::shared_ptr<ExecutionBurstController> controllerRegular;
287     ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
288             preparedModel, kExecutionBurstChannelLength, &controllerRegular));
289     ASSERT_NE(nullptr, controllerRegular.get());
290 
291     // create burst with small output channel
292     std::shared_ptr<ExecutionBurstController> controllerSmall;
293     ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
294             preparedModel, kExecutionBurstChannelSmallLength, &controllerSmall));
295     ASSERT_NE(nullptr, controllerSmall.get());
296 
297     // load memory into callback slots
298     std::vector<intptr_t> keys(request.pools.size());
299     for (size_t i = 0; i < keys.size(); ++i) {
300         keys[i] = reinterpret_cast<intptr_t>(&request.pools[i]);
301     }
302 
303     // collect serialized result by running regular burst
304     const auto [nRegular, outputShapesRegular, timingRegular, fallbackRegular] =
305             controllerRegular->compute(request, MeasureTiming::NO, keys);
306     const V1_0::ErrorStatus statusRegular = nn::legacyConvertResultCodeToErrorStatus(nRegular);
307     EXPECT_FALSE(fallbackRegular);
308 
309     // skip test if regular burst output isn't useful for testing a failure
310     // caused by having too small of a length for the result FMQ
311     const std::vector<FmqResultDatum> serialized =
312             android::nn::serialize(statusRegular, outputShapesRegular, timingRegular);
313     if (statusRegular != V1_0::ErrorStatus::NONE ||
314         serialized.size() <= kExecutionBurstChannelSmallLength) {
315         return;
316     }
317 
318     // by this point, execution should fail because the result channel isn't
319     // large enough to return the serialized result
320     const auto [nSmall, outputShapesSmall, timingSmall, fallbackSmall] =
321             controllerSmall->compute(request, MeasureTiming::NO, keys);
322     const V1_0::ErrorStatus statusSmall = nn::legacyConvertResultCodeToErrorStatus(nSmall);
323     EXPECT_NE(V1_0::ErrorStatus::NONE, statusSmall);
324     EXPECT_EQ(0u, outputShapesSmall.size());
325     EXPECT_TRUE(badTiming(timingSmall));
326     EXPECT_FALSE(fallbackSmall);
327 }
328 
isSanitized(const FmqResultDatum & datum)329 static bool isSanitized(const FmqResultDatum& datum) {
330     using Discriminator = FmqResultDatum::hidl_discriminator;
331 
332     // check to ensure the padding values in the returned
333     // FmqResultDatum::OperandInformation are initialized to 0
334     if (datum.getDiscriminator() == Discriminator::operandInformation) {
335         static_assert(
336                 offsetof(FmqResultDatum::OperandInformation, isSufficient) == 0,
337                 "unexpected value for offset of FmqResultDatum::OperandInformation::isSufficient");
338         static_assert(
339                 sizeof(FmqResultDatum::OperandInformation::isSufficient) == 1,
340                 "unexpected value for size of FmqResultDatum::OperandInformation::isSufficient");
341         static_assert(offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) == 4,
342                       "unexpected value for offset of "
343                       "FmqResultDatum::OperandInformation::numberOfDimensions");
344         static_assert(sizeof(FmqResultDatum::OperandInformation::numberOfDimensions) == 4,
345                       "unexpected value for size of "
346                       "FmqResultDatum::OperandInformation::numberOfDimensions");
347         static_assert(sizeof(FmqResultDatum::OperandInformation) == 8,
348                       "unexpected value for size of "
349                       "FmqResultDatum::OperandInformation");
350 
351         constexpr size_t paddingOffset =
352                 offsetof(FmqResultDatum::OperandInformation, isSufficient) +
353                 sizeof(FmqResultDatum::OperandInformation::isSufficient);
354         constexpr size_t paddingSize =
355                 offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) - paddingOffset;
356 
357         FmqResultDatum::OperandInformation initialized{};
358         std::memset(&initialized, 0, sizeof(initialized));
359 
360         const char* initializedPaddingStart =
361                 reinterpret_cast<const char*>(&initialized) + paddingOffset;
362         const char* datumPaddingStart =
363                 reinterpret_cast<const char*>(&datum.operandInformation()) + paddingOffset;
364 
365         return std::memcmp(datumPaddingStart, initializedPaddingStart, paddingSize) == 0;
366     }
367 
368     // there are no other padding initialization checks required, so return true
369     // for any sum-type that isn't FmqResultDatum::OperandInformation
370     return true;
371 }
372 
validateBurstSanitized(const sp<IPreparedModel> & preparedModel,const Request & request)373 static void validateBurstSanitized(const sp<IPreparedModel>& preparedModel,
374                                    const Request& request) {
375     // create burst
376     std::unique_ptr<RequestChannelSender> sender;
377     std::unique_ptr<ResultChannelReceiver> receiver;
378     sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
379     sp<IBurstContext> context;
380     ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
381     ASSERT_NE(nullptr, sender.get());
382     ASSERT_NE(nullptr, receiver.get());
383     ASSERT_NE(nullptr, context.get());
384 
385     // load memory into callback slots
386     std::vector<intptr_t> keys;
387     keys.reserve(request.pools.size());
388     std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
389                    [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
390     const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
391 
392     // send valid request
393     ASSERT_TRUE(sender->send(request, MeasureTiming::YES, slots));
394 
395     // receive valid result
396     auto serialized = receiver->getPacketBlocking();
397     ASSERT_TRUE(serialized.has_value());
398 
399     // sanitize result
400     ASSERT_TRUE(std::all_of(serialized->begin(), serialized->end(), isSanitized))
401             << "The result serialized data is not properly sanitized";
402 }
403 
404 ///////////////////////////// ENTRY POINT //////////////////////////////////
405 
// Entry point for the burst validation tests. Runs, in order, the
// serialization, result-FMQ-length, and sanitization checks on
// |preparedModel| using |request|. It stops at the first check that reports a
// fatal failure.
void validateBurst(const sp<IPreparedModel>& preparedModel, const Request& request) {
    ASSERT_NO_FATAL_FAILURE(validateBurstSerialization(preparedModel, request));
    ASSERT_NO_FATAL_FAILURE(validateBurstFmqLength(preparedModel, request));
    ASSERT_NO_FATAL_FAILURE(validateBurstSanitized(preparedModel, request));
}
411 
412 }  // namespace android::hardware::neuralnetworks::V1_3::vts::functional
413