1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
18
19 #include "VtsHalNeuralnetworks.h"
20
21 #include "1.2/Callbacks.h"
22 #include "ExecutionBurstController.h"
23 #include "ExecutionBurstServer.h"
24 #include "GeneratedTestHarness.h"
25 #include "TestHarness.h"
26
27 #include <android-base/logging.h>
28 #include <chrono>
29 #include <cstring>
30
31 namespace android::hardware::neuralnetworks::V1_3::vts::functional {
32
33 using nn::ExecutionBurstController;
34 using nn::RequestChannelSender;
35 using nn::ResultChannelReceiver;
36 using V1_0::Request;
37 using V1_2::FmqRequestDatum;
38 using V1_2::FmqResultDatum;
39 using V1_2::IBurstCallback;
40 using V1_2::IBurstContext;
41 using V1_2::MeasureTiming;
42 using V1_2::Timing;
43 using ExecutionBurstCallback = ExecutionBurstController::ExecutionBurstCallback;
44
45 using BurstExecutionMutation = std::function<void(std::vector<FmqRequestDatum>*)>;
46
// Default length of the FMQs used by these tests. It is large enough to return
// the result of a burst execution for every generated test case.
constexpr size_t kExecutionBurstChannelLength{1024};

// Length of a deliberately undersized result FMQ: it is NOT large enough to
// return the result of a burst execution for some of the generated test cases.
constexpr size_t kExecutionBurstChannelSmallLength{8};
55
56 ///////////////////////// UTILITY FUNCTIONS /////////////////////////
57
badTiming(Timing timing)58 static bool badTiming(Timing timing) {
59 return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX;
60 }
61
createBurst(const sp<IPreparedModel> & preparedModel,const sp<IBurstCallback> & callback,std::unique_ptr<RequestChannelSender> * sender,std::unique_ptr<ResultChannelReceiver> * receiver,sp<IBurstContext> * context,size_t resultChannelLength=kExecutionBurstChannelLength)62 static void createBurst(const sp<IPreparedModel>& preparedModel, const sp<IBurstCallback>& callback,
63 std::unique_ptr<RequestChannelSender>* sender,
64 std::unique_ptr<ResultChannelReceiver>* receiver,
65 sp<IBurstContext>* context,
66 size_t resultChannelLength = kExecutionBurstChannelLength) {
67 ASSERT_NE(nullptr, preparedModel.get());
68 ASSERT_NE(nullptr, sender);
69 ASSERT_NE(nullptr, receiver);
70 ASSERT_NE(nullptr, context);
71
72 // create FMQ objects
73 auto [fmqRequestChannel, fmqRequestDescriptor] =
74 RequestChannelSender::create(kExecutionBurstChannelLength);
75 auto [fmqResultChannel, fmqResultDescriptor] =
76 ResultChannelReceiver::create(resultChannelLength, std::chrono::microseconds{0});
77 ASSERT_NE(nullptr, fmqRequestChannel.get());
78 ASSERT_NE(nullptr, fmqResultChannel.get());
79 ASSERT_NE(nullptr, fmqRequestDescriptor);
80 ASSERT_NE(nullptr, fmqResultDescriptor);
81
82 // configure burst
83 V1_0::ErrorStatus errorStatus;
84 sp<IBurstContext> burstContext;
85 const Return<void> ret = preparedModel->configureExecutionBurst(
86 callback, *fmqRequestDescriptor, *fmqResultDescriptor,
87 [&errorStatus, &burstContext](V1_0::ErrorStatus status,
88 const sp<IBurstContext>& context) {
89 errorStatus = status;
90 burstContext = context;
91 });
92 ASSERT_TRUE(ret.isOk());
93 ASSERT_EQ(V1_0::ErrorStatus::NONE, errorStatus);
94 ASSERT_NE(nullptr, burstContext.get());
95
96 // return values
97 *sender = std::move(fmqRequestChannel);
98 *receiver = std::move(fmqResultChannel);
99 *context = burstContext;
100 }
101
createBurstWithResultChannelLength(const sp<IPreparedModel> & preparedModel,size_t resultChannelLength,std::shared_ptr<ExecutionBurstController> * controller)102 static void createBurstWithResultChannelLength(
103 const sp<IPreparedModel>& preparedModel, size_t resultChannelLength,
104 std::shared_ptr<ExecutionBurstController>* controller) {
105 ASSERT_NE(nullptr, preparedModel.get());
106 ASSERT_NE(nullptr, controller);
107
108 // create FMQ objects
109 std::unique_ptr<RequestChannelSender> sender;
110 std::unique_ptr<ResultChannelReceiver> receiver;
111 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
112 sp<IBurstContext> context;
113 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context,
114 resultChannelLength));
115 ASSERT_NE(nullptr, sender.get());
116 ASSERT_NE(nullptr, receiver.get());
117 ASSERT_NE(nullptr, context.get());
118
119 // return values
120 *controller = std::make_shared<ExecutionBurstController>(std::move(sender), std::move(receiver),
121 context, callback);
122 }
123
124 // Primary validation function. This function will take a valid serialized
125 // request, apply a mutation to it to invalidate the serialized request, then
126 // pass it to interface calls that use the serialized request.
validate(RequestChannelSender * sender,ResultChannelReceiver * receiver,const std::string & message,const std::vector<FmqRequestDatum> & originalSerialized,const BurstExecutionMutation & mutate)127 static void validate(RequestChannelSender* sender, ResultChannelReceiver* receiver,
128 const std::string& message,
129 const std::vector<FmqRequestDatum>& originalSerialized,
130 const BurstExecutionMutation& mutate) {
131 std::vector<FmqRequestDatum> serialized = originalSerialized;
132 mutate(&serialized);
133
134 // skip if packet is too large to send
135 if (serialized.size() > kExecutionBurstChannelLength) {
136 return;
137 }
138
139 SCOPED_TRACE(message);
140
141 // send invalid packet
142 ASSERT_TRUE(sender->sendPacket(serialized));
143
144 // receive error
145 auto results = receiver->getBlocking();
146 ASSERT_TRUE(results.has_value());
147 const auto [status, outputShapes, timing] = std::move(*results);
148 EXPECT_NE(V1_0::ErrorStatus::NONE, status);
149 EXPECT_EQ(0u, outputShapes.size());
150 EXPECT_TRUE(badTiming(timing));
151 }
152
153 // For validation, valid packet entries are mutated to invalid packet entries,
154 // or invalid packet entries are inserted into valid packets. This function
155 // creates pre-set invalid packet entries for convenience.
createBadRequestPacketEntries()156 static std::vector<FmqRequestDatum> createBadRequestPacketEntries() {
157 const FmqRequestDatum::PacketInformation packetInformation = {
158 /*.packetSize=*/10, /*.numberOfInputOperands=*/10, /*.numberOfOutputOperands=*/10,
159 /*.numberOfPools=*/10};
160 const FmqRequestDatum::OperandInformation operandInformation = {
161 /*.hasNoValue=*/false, /*.location=*/{}, /*.numberOfDimensions=*/10};
162 const int32_t invalidPoolIdentifier = std::numeric_limits<int32_t>::max();
163 std::vector<FmqRequestDatum> bad(7);
164 bad[0].packetInformation(packetInformation);
165 bad[1].inputOperandInformation(operandInformation);
166 bad[2].inputOperandDimensionValue(0);
167 bad[3].outputOperandInformation(operandInformation);
168 bad[4].outputOperandDimensionValue(0);
169 bad[5].poolIdentifier(invalidPoolIdentifier);
170 bad[6].measureTiming(MeasureTiming::YES);
171 return bad;
172 }
173
174 // For validation, valid packet entries are mutated to invalid packet entries,
175 // or invalid packet entries are inserted into valid packets. This function
176 // retrieves pre-set invalid packet entries for convenience. This function
177 // caches these data so they can be reused on subsequent validation checks.
getBadRequestPacketEntries()178 static const std::vector<FmqRequestDatum>& getBadRequestPacketEntries() {
179 static const std::vector<FmqRequestDatum> bad = createBadRequestPacketEntries();
180 return bad;
181 }
182
183 ///////////////////////// REMOVE DATUM ////////////////////////////////////
184
removeDatumTest(RequestChannelSender * sender,ResultChannelReceiver * receiver,const std::vector<FmqRequestDatum> & serialized)185 static void removeDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
186 const std::vector<FmqRequestDatum>& serialized) {
187 for (size_t index = 0; index < serialized.size(); ++index) {
188 const std::string message = "removeDatum: removed datum at index " + std::to_string(index);
189 validate(sender, receiver, message, serialized,
190 [index](std::vector<FmqRequestDatum>* serialized) {
191 serialized->erase(serialized->begin() + index);
192 });
193 }
194 }
195
196 ///////////////////////// ADD DATUM ////////////////////////////////////
197
addDatumTest(RequestChannelSender * sender,ResultChannelReceiver * receiver,const std::vector<FmqRequestDatum> & serialized)198 static void addDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
199 const std::vector<FmqRequestDatum>& serialized) {
200 const std::vector<FmqRequestDatum>& extra = getBadRequestPacketEntries();
201 for (size_t index = 0; index <= serialized.size(); ++index) {
202 for (size_t type = 0; type < extra.size(); ++type) {
203 const std::string message = "addDatum: added datum type " + std::to_string(type) +
204 " at index " + std::to_string(index);
205 validate(sender, receiver, message, serialized,
206 [index, type, &extra](std::vector<FmqRequestDatum>* serialized) {
207 serialized->insert(serialized->begin() + index, extra[type]);
208 });
209 }
210 }
211 }
212
213 ///////////////////////// MUTATE DATUM ////////////////////////////////////
214
interestingCase(const FmqRequestDatum & lhs,const FmqRequestDatum & rhs)215 static bool interestingCase(const FmqRequestDatum& lhs, const FmqRequestDatum& rhs) {
216 using Discriminator = FmqRequestDatum::hidl_discriminator;
217
218 const bool differentValues = (lhs != rhs);
219 const bool sameDiscriminator = (lhs.getDiscriminator() == rhs.getDiscriminator());
220 const auto discriminator = rhs.getDiscriminator();
221 const bool isDimensionValue = (discriminator == Discriminator::inputOperandDimensionValue ||
222 discriminator == Discriminator::outputOperandDimensionValue);
223
224 return differentValues && !(sameDiscriminator && isDimensionValue);
225 }
226
mutateDatumTest(RequestChannelSender * sender,ResultChannelReceiver * receiver,const std::vector<FmqRequestDatum> & serialized)227 static void mutateDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
228 const std::vector<FmqRequestDatum>& serialized) {
229 const std::vector<FmqRequestDatum>& change = getBadRequestPacketEntries();
230 for (size_t index = 0; index < serialized.size(); ++index) {
231 for (size_t type = 0; type < change.size(); ++type) {
232 if (interestingCase(serialized[index], change[type])) {
233 const std::string message = "mutateDatum: changed datum at index " +
234 std::to_string(index) + " to datum type " +
235 std::to_string(type);
236 validate(sender, receiver, message, serialized,
237 [index, type, &change](std::vector<FmqRequestDatum>* serialized) {
238 (*serialized)[index] = change[type];
239 });
240 }
241 }
242 }
243 }
244
///////////////////////// BURST VALIDATION TESTS ////////////////////////////////////
246
validateBurstSerialization(const sp<IPreparedModel> & preparedModel,const Request & request)247 static void validateBurstSerialization(const sp<IPreparedModel>& preparedModel,
248 const Request& request) {
249 // create burst
250 std::unique_ptr<RequestChannelSender> sender;
251 std::unique_ptr<ResultChannelReceiver> receiver;
252 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
253 sp<IBurstContext> context;
254 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
255 ASSERT_NE(nullptr, sender.get());
256 ASSERT_NE(nullptr, receiver.get());
257 ASSERT_NE(nullptr, context.get());
258
259 // load memory into callback slots
260 std::vector<intptr_t> keys;
261 keys.reserve(request.pools.size());
262 std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
263 [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
264 const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
265
266 // ensure slot std::numeric_limits<int32_t>::max() doesn't exist (for
267 // subsequent slot validation testing)
268 ASSERT_TRUE(std::all_of(slots.begin(), slots.end(), [](int32_t slot) {
269 return slot != std::numeric_limits<int32_t>::max();
270 }));
271
272 // serialize the request
273 const auto serialized = android::nn::serialize(request, MeasureTiming::YES, slots);
274
275 // validations
276 removeDatumTest(sender.get(), receiver.get(), serialized);
277 addDatumTest(sender.get(), receiver.get(), serialized);
278 mutateDatumTest(sender.get(), receiver.get(), serialized);
279 }
280
281 // This test validates that when the Result message size exceeds length of the
282 // result FMQ, the service instance gracefully fails and returns an error.
validateBurstFmqLength(const sp<IPreparedModel> & preparedModel,const Request & request)283 static void validateBurstFmqLength(const sp<IPreparedModel>& preparedModel,
284 const Request& request) {
285 // create regular burst
286 std::shared_ptr<ExecutionBurstController> controllerRegular;
287 ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
288 preparedModel, kExecutionBurstChannelLength, &controllerRegular));
289 ASSERT_NE(nullptr, controllerRegular.get());
290
291 // create burst with small output channel
292 std::shared_ptr<ExecutionBurstController> controllerSmall;
293 ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
294 preparedModel, kExecutionBurstChannelSmallLength, &controllerSmall));
295 ASSERT_NE(nullptr, controllerSmall.get());
296
297 // load memory into callback slots
298 std::vector<intptr_t> keys(request.pools.size());
299 for (size_t i = 0; i < keys.size(); ++i) {
300 keys[i] = reinterpret_cast<intptr_t>(&request.pools[i]);
301 }
302
303 // collect serialized result by running regular burst
304 const auto [nRegular, outputShapesRegular, timingRegular, fallbackRegular] =
305 controllerRegular->compute(request, MeasureTiming::NO, keys);
306 const V1_0::ErrorStatus statusRegular = nn::legacyConvertResultCodeToErrorStatus(nRegular);
307 EXPECT_FALSE(fallbackRegular);
308
309 // skip test if regular burst output isn't useful for testing a failure
310 // caused by having too small of a length for the result FMQ
311 const std::vector<FmqResultDatum> serialized =
312 android::nn::serialize(statusRegular, outputShapesRegular, timingRegular);
313 if (statusRegular != V1_0::ErrorStatus::NONE ||
314 serialized.size() <= kExecutionBurstChannelSmallLength) {
315 return;
316 }
317
318 // by this point, execution should fail because the result channel isn't
319 // large enough to return the serialized result
320 const auto [nSmall, outputShapesSmall, timingSmall, fallbackSmall] =
321 controllerSmall->compute(request, MeasureTiming::NO, keys);
322 const V1_0::ErrorStatus statusSmall = nn::legacyConvertResultCodeToErrorStatus(nSmall);
323 EXPECT_NE(V1_0::ErrorStatus::NONE, statusSmall);
324 EXPECT_EQ(0u, outputShapesSmall.size());
325 EXPECT_TRUE(badTiming(timingSmall));
326 EXPECT_FALSE(fallbackSmall);
327 }
328
isSanitized(const FmqResultDatum & datum)329 static bool isSanitized(const FmqResultDatum& datum) {
330 using Discriminator = FmqResultDatum::hidl_discriminator;
331
332 // check to ensure the padding values in the returned
333 // FmqResultDatum::OperandInformation are initialized to 0
334 if (datum.getDiscriminator() == Discriminator::operandInformation) {
335 static_assert(
336 offsetof(FmqResultDatum::OperandInformation, isSufficient) == 0,
337 "unexpected value for offset of FmqResultDatum::OperandInformation::isSufficient");
338 static_assert(
339 sizeof(FmqResultDatum::OperandInformation::isSufficient) == 1,
340 "unexpected value for size of FmqResultDatum::OperandInformation::isSufficient");
341 static_assert(offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) == 4,
342 "unexpected value for offset of "
343 "FmqResultDatum::OperandInformation::numberOfDimensions");
344 static_assert(sizeof(FmqResultDatum::OperandInformation::numberOfDimensions) == 4,
345 "unexpected value for size of "
346 "FmqResultDatum::OperandInformation::numberOfDimensions");
347 static_assert(sizeof(FmqResultDatum::OperandInformation) == 8,
348 "unexpected value for size of "
349 "FmqResultDatum::OperandInformation");
350
351 constexpr size_t paddingOffset =
352 offsetof(FmqResultDatum::OperandInformation, isSufficient) +
353 sizeof(FmqResultDatum::OperandInformation::isSufficient);
354 constexpr size_t paddingSize =
355 offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) - paddingOffset;
356
357 FmqResultDatum::OperandInformation initialized{};
358 std::memset(&initialized, 0, sizeof(initialized));
359
360 const char* initializedPaddingStart =
361 reinterpret_cast<const char*>(&initialized) + paddingOffset;
362 const char* datumPaddingStart =
363 reinterpret_cast<const char*>(&datum.operandInformation()) + paddingOffset;
364
365 return std::memcmp(datumPaddingStart, initializedPaddingStart, paddingSize) == 0;
366 }
367
368 // there are no other padding initialization checks required, so return true
369 // for any sum-type that isn't FmqResultDatum::OperandInformation
370 return true;
371 }
372
validateBurstSanitized(const sp<IPreparedModel> & preparedModel,const Request & request)373 static void validateBurstSanitized(const sp<IPreparedModel>& preparedModel,
374 const Request& request) {
375 // create burst
376 std::unique_ptr<RequestChannelSender> sender;
377 std::unique_ptr<ResultChannelReceiver> receiver;
378 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
379 sp<IBurstContext> context;
380 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
381 ASSERT_NE(nullptr, sender.get());
382 ASSERT_NE(nullptr, receiver.get());
383 ASSERT_NE(nullptr, context.get());
384
385 // load memory into callback slots
386 std::vector<intptr_t> keys;
387 keys.reserve(request.pools.size());
388 std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
389 [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
390 const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
391
392 // send valid request
393 ASSERT_TRUE(sender->send(request, MeasureTiming::YES, slots));
394
395 // receive valid result
396 auto serialized = receiver->getPacketBlocking();
397 ASSERT_TRUE(serialized.has_value());
398
399 // sanitize result
400 ASSERT_TRUE(std::all_of(serialized->begin(), serialized->end(), isSanitized))
401 << "The result serialized data is not properly sanitized";
402 }
403
404 ///////////////////////////// ENTRY POINT //////////////////////////////////
405
validateBurst(const sp<IPreparedModel> & preparedModel,const Request & request)406 void validateBurst(const sp<IPreparedModel>& preparedModel, const Request& request) {
407 ASSERT_NO_FATAL_FAILURE(validateBurstSerialization(preparedModel, request));
408 ASSERT_NO_FATAL_FAILURE(validateBurstFmqLength(preparedModel, request));
409 ASSERT_NO_FATAL_FAILURE(validateBurstSanitized(preparedModel, request));
410 }
411
412 } // namespace android::hardware::neuralnetworks::V1_3::vts::functional
413