1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "ExecutionBurstController"
18
19 #include "ExecutionBurstController.h"
20
21 #include <android-base/logging.h>
22 #include <cstring>
23 #include <limits>
24 #include <string>
25 #include "Tracing.h"
26
27 namespace android::nn {
28 namespace {
29
30 using ::android::hardware::MQDescriptorSync;
31 using FmqRequestDescriptor = MQDescriptorSync<FmqRequestDatum>;
32 using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
33
34 constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
35 std::numeric_limits<uint64_t>::max()};
36
37 class BurstContextDeathHandler : public hardware::hidl_death_recipient {
38 public:
39 using Callback = std::function<void()>;
40
BurstContextDeathHandler(const Callback & onDeathCallback)41 BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) {
42 CHECK(onDeathCallback != nullptr);
43 }
44
serviceDied(uint64_t,const wp<hidl::base::V1_0::IBase> &)45 void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
46 LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!";
47 mOnDeathCallback();
48 }
49
50 private:
51 const Callback mOnDeathCallback;
52 };
53
54 } // anonymous namespace
55
56 // serialize a request into a packet
serialize(const Request & request,MeasureTiming measure,const std::vector<int32_t> & slots)57 std::vector<FmqRequestDatum> serialize(const Request& request, MeasureTiming measure,
58 const std::vector<int32_t>& slots) {
59 // count how many elements need to be sent for a request
60 size_t count = 2 + request.inputs.size() + request.outputs.size() + request.pools.size();
61 for (const auto& input : request.inputs) {
62 count += input.dimensions.size();
63 }
64 for (const auto& output : request.outputs) {
65 count += output.dimensions.size();
66 }
67
68 // create buffer to temporarily store elements
69 std::vector<FmqRequestDatum> data;
70 data.reserve(count);
71
72 // package packetInfo
73 {
74 FmqRequestDatum datum;
75 datum.packetInformation(
76 {/*.packetSize=*/static_cast<uint32_t>(count),
77 /*.numberOfInputOperands=*/static_cast<uint32_t>(request.inputs.size()),
78 /*.numberOfOutputOperands=*/static_cast<uint32_t>(request.outputs.size()),
79 /*.numberOfPools=*/static_cast<uint32_t>(request.pools.size())});
80 data.push_back(datum);
81 }
82
83 // package input data
84 for (const auto& input : request.inputs) {
85 // package operand information
86 FmqRequestDatum datum;
87 datum.inputOperandInformation(
88 {/*.hasNoValue=*/input.hasNoValue,
89 /*.location=*/input.location,
90 /*.numberOfDimensions=*/static_cast<uint32_t>(input.dimensions.size())});
91 data.push_back(datum);
92
93 // package operand dimensions
94 for (uint32_t dimension : input.dimensions) {
95 FmqRequestDatum datum;
96 datum.inputOperandDimensionValue(dimension);
97 data.push_back(datum);
98 }
99 }
100
101 // package output data
102 for (const auto& output : request.outputs) {
103 // package operand information
104 FmqRequestDatum datum;
105 datum.outputOperandInformation(
106 {/*.hasNoValue=*/output.hasNoValue,
107 /*.location=*/output.location,
108 /*.numberOfDimensions=*/static_cast<uint32_t>(output.dimensions.size())});
109 data.push_back(datum);
110
111 // package operand dimensions
112 for (uint32_t dimension : output.dimensions) {
113 FmqRequestDatum datum;
114 datum.outputOperandDimensionValue(dimension);
115 data.push_back(datum);
116 }
117 }
118
119 // package pool identifier
120 for (int32_t slot : slots) {
121 FmqRequestDatum datum;
122 datum.poolIdentifier(slot);
123 data.push_back(datum);
124 }
125
126 // package measureTiming
127 {
128 FmqRequestDatum datum;
129 datum.measureTiming(measure);
130 data.push_back(datum);
131 }
132
133 // return packet
134 return data;
135 }
136
137 // deserialize a packet into the result
deserialize(const std::vector<FmqResultDatum> & data)138 std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>> deserialize(
139 const std::vector<FmqResultDatum>& data) {
140 using discriminator = FmqResultDatum::hidl_discriminator;
141
142 std::vector<OutputShape> outputShapes;
143 size_t index = 0;
144
145 // validate packet information
146 if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
147 LOG(ERROR) << "FMQ Result packet ill-formed";
148 return std::nullopt;
149 }
150
151 // unpackage packet information
152 const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation();
153 index++;
154 const uint32_t packetSize = packetInfo.packetSize;
155 const ErrorStatus errorStatus = packetInfo.errorStatus;
156 const uint32_t numberOfOperands = packetInfo.numberOfOperands;
157
158 // verify packet size
159 if (data.size() != packetSize) {
160 LOG(ERROR) << "FMQ Result packet ill-formed";
161 return std::nullopt;
162 }
163
164 // unpackage operands
165 for (size_t operand = 0; operand < numberOfOperands; ++operand) {
166 // validate operand information
167 if (data[index].getDiscriminator() != discriminator::operandInformation) {
168 LOG(ERROR) << "FMQ Result packet ill-formed";
169 return std::nullopt;
170 }
171
172 // unpackage operand information
173 const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation();
174 index++;
175 const bool isSufficient = operandInfo.isSufficient;
176 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
177
178 // unpackage operand dimensions
179 std::vector<uint32_t> dimensions;
180 dimensions.reserve(numberOfDimensions);
181 for (size_t i = 0; i < numberOfDimensions; ++i) {
182 // validate dimension
183 if (data[index].getDiscriminator() != discriminator::operandDimensionValue) {
184 LOG(ERROR) << "FMQ Result packet ill-formed";
185 return std::nullopt;
186 }
187
188 // unpackage dimension
189 const uint32_t dimension = data[index].operandDimensionValue();
190 index++;
191
192 // store result
193 dimensions.push_back(dimension);
194 }
195
196 // store result
197 outputShapes.push_back({/*.dimensions=*/dimensions, /*.isSufficient=*/isSufficient});
198 }
199
200 // validate execution timing
201 if (data[index].getDiscriminator() != discriminator::executionTiming) {
202 LOG(ERROR) << "FMQ Result packet ill-formed";
203 return std::nullopt;
204 }
205
206 // unpackage execution timing
207 const Timing timing = data[index].executionTiming();
208 index++;
209
210 // validate packet information
211 if (index != packetSize) {
212 LOG(ERROR) << "FMQ Result packet ill-formed";
213 return std::nullopt;
214 }
215
216 // return result
217 return std::make_tuple(errorStatus, std::move(outputShapes), timing);
218 }
219
220 std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*>
create(size_t channelLength,bool blocking)221 ResultChannelReceiver::create(size_t channelLength, bool blocking) {
222 std::unique_ptr<FmqResultChannel> fmqResultChannel =
223 std::make_unique<FmqResultChannel>(channelLength, /*confEventFlag=*/blocking);
224 if (!fmqResultChannel->isValid()) {
225 LOG(ERROR) << "Unable to create ResultChannelReceiver";
226 return {nullptr, nullptr};
227 }
228 const FmqResultDescriptor* descriptor = fmqResultChannel->getDesc();
229 return std::make_pair(
230 std::make_unique<ResultChannelReceiver>(std::move(fmqResultChannel), blocking),
231 descriptor);
232 }
233
ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,bool blocking)234 ResultChannelReceiver::ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,
235 bool blocking)
236 : mFmqResultChannel(std::move(fmqResultChannel)), mBlocking(blocking) {}
237
238 std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>>
getBlocking()239 ResultChannelReceiver::getBlocking() {
240 const auto packet = getPacketBlocking();
241 if (!packet) {
242 return std::nullopt;
243 }
244
245 return deserialize(*packet);
246 }
247
invalidate()248 void ResultChannelReceiver::invalidate() {
249 mValid = false;
250
251 // force unblock
252 // ExecutionBurstController waits on a result packet after sending a
253 // request. If the driver containing ExecutionBurstServer crashes, the
254 // controller will still be waiting on the futex (assuming mBlocking is
255 // true). This force unblock wakes up any thread waiting on the futex.
256 if (mBlocking) {
257 // TODO: look for a different/better way to signal/notify the futex to
258 // wake up any thread waiting on it
259 FmqResultDatum datum;
260 datum.packetInformation({/*.packetSize=*/0, /*.errorStatus=*/ErrorStatus::GENERAL_FAILURE,
261 /*.numberOfOperands=*/0});
262 mFmqResultChannel->writeBlocking(&datum, 1);
263 }
264 }
265
getPacketBlocking()266 std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
267 using discriminator = FmqResultDatum::hidl_discriminator;
268
269 if (!mValid) {
270 return std::nullopt;
271 }
272
273 // wait for result packet and read first element of result packet
274 FmqResultDatum datum;
275 bool success = true;
276 if (mBlocking) {
277 success = mFmqResultChannel->readBlocking(&datum, 1);
278 } else {
279 while ((success = mValid.load(std::memory_order_relaxed)) &&
280 !mFmqResultChannel->read(&datum, 1)) {
281 }
282 }
283
284 // retrieve remaining elements
285 // NOTE: all of the data is already available at this point, so there's no
286 // need to do a blocking wait to wait for more data. This is known because
287 // in FMQ, all writes are published (made available) atomically. Currently,
288 // the producer always publishes the entire packet in one function call, so
289 // if the first element of the packet is available, the remaining elements
290 // are also available.
291 const size_t count = mFmqResultChannel->availableToRead();
292 std::vector<FmqResultDatum> packet(count + 1);
293 std::memcpy(&packet.front(), &datum, sizeof(datum));
294 success &= mFmqResultChannel->read(packet.data() + 1, count);
295
296 if (!mValid) {
297 return std::nullopt;
298 }
299
300 // ensure packet was successfully received
301 if (!success) {
302 LOG(ERROR) << "Error receiving packet";
303 return std::nullopt;
304 }
305
306 return std::make_optional(std::move(packet));
307 }
308
309 std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*>
create(size_t channelLength,bool blocking)310 RequestChannelSender::create(size_t channelLength, bool blocking) {
311 std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
312 std::make_unique<FmqRequestChannel>(channelLength, /*confEventFlag=*/blocking);
313 if (!fmqRequestChannel->isValid()) {
314 LOG(ERROR) << "Unable to create RequestChannelSender";
315 return {nullptr, nullptr};
316 }
317 const FmqRequestDescriptor* descriptor = fmqRequestChannel->getDesc();
318 return std::make_pair(
319 std::make_unique<RequestChannelSender>(std::move(fmqRequestChannel), blocking),
320 descriptor);
321 }
322
RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel,bool blocking)323 RequestChannelSender::RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel,
324 bool blocking)
325 : mFmqRequestChannel(std::move(fmqRequestChannel)), mBlocking(blocking) {}
326
send(const Request & request,MeasureTiming measure,const std::vector<int32_t> & slots)327 bool RequestChannelSender::send(const Request& request, MeasureTiming measure,
328 const std::vector<int32_t>& slots) {
329 const std::vector<FmqRequestDatum> serialized = serialize(request, measure, slots);
330 return sendPacket(serialized);
331 }
332
sendPacket(const std::vector<FmqRequestDatum> & packet)333 bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
334 if (!mValid) {
335 return false;
336 }
337
338 if (packet.size() > mFmqRequestChannel->availableToWrite()) {
339 LOG(ERROR)
340 << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ";
341 return false;
342 }
343
344 if (mBlocking) {
345 return mFmqRequestChannel->writeBlocking(packet.data(), packet.size());
346 } else {
347 return mFmqRequestChannel->write(packet.data(), packet.size());
348 }
349 }
350
invalidate()351 void RequestChannelSender::invalidate() {
352 mValid = false;
353 }
354
getMemories(const hidl_vec<int32_t> & slots,getMemories_cb cb)355 Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
356 const hidl_vec<int32_t>& slots, getMemories_cb cb) {
357 std::lock_guard<std::mutex> guard(mMutex);
358
359 // get all memories
360 hidl_vec<hidl_memory> memories(slots.size());
361 std::transform(slots.begin(), slots.end(), memories.begin(), [this](int32_t slot) {
362 return slot < mMemoryCache.size() ? mMemoryCache[slot] : hidl_memory{};
363 });
364
365 // ensure all memories are valid
366 if (!std::all_of(memories.begin(), memories.end(),
367 [](const hidl_memory& memory) { return memory.valid(); })) {
368 cb(ErrorStatus::INVALID_ARGUMENT, {});
369 return Void();
370 }
371
372 // return successful
373 cb(ErrorStatus::NONE, std::move(memories));
374 return Void();
375 }
376
getSlots(const hidl_vec<hidl_memory> & memories,const std::vector<intptr_t> & keys)377 std::vector<int32_t> ExecutionBurstController::ExecutionBurstCallback::getSlots(
378 const hidl_vec<hidl_memory>& memories, const std::vector<intptr_t>& keys) {
379 std::lock_guard<std::mutex> guard(mMutex);
380
381 // retrieve (or bind) all slots corresponding to memories
382 std::vector<int32_t> slots;
383 slots.reserve(memories.size());
384 for (size_t i = 0; i < memories.size(); ++i) {
385 slots.push_back(getSlotLocked(memories[i], keys[i]));
386 }
387 return slots;
388 }
389
freeMemory(intptr_t key)390 std::pair<bool, int32_t> ExecutionBurstController::ExecutionBurstCallback::freeMemory(
391 intptr_t key) {
392 std::lock_guard<std::mutex> guard(mMutex);
393
394 auto iter = mMemoryIdToSlot.find(key);
395 if (iter == mMemoryIdToSlot.end()) {
396 return {false, 0};
397 }
398 const int32_t slot = iter->second;
399 mMemoryIdToSlot.erase(key);
400 mMemoryCache[slot] = {};
401 mFreeSlots.push(slot);
402 return {true, slot};
403 }
404
getSlotLocked(const hidl_memory & memory,intptr_t key)405 int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked(const hidl_memory& memory,
406 intptr_t key) {
407 auto iter = mMemoryIdToSlot.find(key);
408 if (iter == mMemoryIdToSlot.end()) {
409 const int32_t slot = allocateSlotLocked();
410 mMemoryIdToSlot[key] = slot;
411 mMemoryCache[slot] = memory;
412 return slot;
413 } else {
414 const int32_t slot = iter->second;
415 return slot;
416 }
417 }
418
allocateSlotLocked()419 int32_t ExecutionBurstController::ExecutionBurstCallback::allocateSlotLocked() {
420 constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max();
421
422 // if there is a free slot, use it
423 if (mFreeSlots.size() > 0) {
424 const int32_t slot = mFreeSlots.top();
425 mFreeSlots.pop();
426 return slot;
427 }
428
429 // otherwise use a slot for the first time
430 CHECK(mMemoryCache.size() < kMaxNumberOfSlots) << "Exceeded maximum number of slots!";
431 const int32_t slot = static_cast<int32_t>(mMemoryCache.size());
432 mMemoryCache.emplace_back();
433
434 return slot;
435 }
436
create(const sp<IPreparedModel> & preparedModel,bool blocking)437 std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
438 const sp<IPreparedModel>& preparedModel, bool blocking) {
439 // check inputs
440 if (preparedModel == nullptr) {
441 LOG(ERROR) << "ExecutionBurstController::create passed a nullptr";
442 return nullptr;
443 }
444
445 // create callback object
446 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
447
448 // create FMQ objects
449 auto [requestChannelSenderTemp, requestChannelDescriptor] =
450 RequestChannelSender::create(kExecutionBurstChannelLength, blocking);
451 auto [resultChannelReceiverTemp, resultChannelDescriptor] =
452 ResultChannelReceiver::create(kExecutionBurstChannelLength, blocking);
453 std::shared_ptr<RequestChannelSender> requestChannelSender =
454 std::move(requestChannelSenderTemp);
455 std::shared_ptr<ResultChannelReceiver> resultChannelReceiver =
456 std::move(resultChannelReceiverTemp);
457
458 // check FMQ objects
459 if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor ||
460 !resultChannelDescriptor) {
461 LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue";
462 return nullptr;
463 }
464
465 // configure burst
466 ErrorStatus errorStatus;
467 sp<IBurstContext> burstContext;
468 const Return<void> ret = preparedModel->configureExecutionBurst(
469 callback, *requestChannelDescriptor, *resultChannelDescriptor,
470 [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) {
471 errorStatus = status;
472 burstContext = context;
473 });
474
475 // check burst
476 if (!ret.isOk()) {
477 LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with description "
478 << ret.description();
479 return nullptr;
480 }
481 if (errorStatus != ErrorStatus::NONE) {
482 LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with status "
483 << toString(errorStatus);
484 return nullptr;
485 }
486 if (burstContext == nullptr) {
487 LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
488 return nullptr;
489 }
490
491 // create death handler object
492 BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender,
493 resultChannelReceiver] {
494 requestChannelSender->invalidate();
495 resultChannelReceiver->invalidate();
496 };
497 const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback);
498
499 // linkToDeath registers a callback that will be invoked on service death to
500 // proactively handle service crashes. If the linkToDeath call fails,
501 // asynchronous calls are susceptible to hangs if the service crashes before
502 // providing the response.
503 const Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0);
504 if (!deathHandlerRet.isOk() || deathHandlerRet != true) {
505 LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient "
506 "for the IBurstContext object.";
507 return nullptr;
508 }
509
510 // make and return controller
511 return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver,
512 burstContext, callback, deathHandler);
513 }
514
ExecutionBurstController(const std::shared_ptr<RequestChannelSender> & requestChannelSender,const std::shared_ptr<ResultChannelReceiver> & resultChannelReceiver,const sp<IBurstContext> & burstContext,const sp<ExecutionBurstCallback> & callback,const sp<hardware::hidl_death_recipient> & deathHandler)515 ExecutionBurstController::ExecutionBurstController(
516 const std::shared_ptr<RequestChannelSender>& requestChannelSender,
517 const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
518 const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback,
519 const sp<hardware::hidl_death_recipient>& deathHandler)
520 : mRequestChannelSender(requestChannelSender),
521 mResultChannelReceiver(resultChannelReceiver),
522 mBurstContext(burstContext),
523 mMemoryCache(callback),
524 mDeathHandler(deathHandler) {}
525
~ExecutionBurstController()526 ExecutionBurstController::~ExecutionBurstController() {
527 // It is safe to ignore any errors resulting from this unlinkToDeath call
528 // because the ExecutionBurstController object is already being destroyed
529 // and its underlying IBurstContext object is no longer being used by the NN
530 // runtime.
531 if (mDeathHandler) {
532 mBurstContext->unlinkToDeath(mDeathHandler).isOk();
533 }
534 }
535
compute(const Request & request,MeasureTiming measure,const std::vector<intptr_t> & memoryIds)536 std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> ExecutionBurstController::compute(
537 const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds) {
538 auto [status, outputShapes, timing, fallback] = tryCompute(request, measure, memoryIds);
539 (void)fallback; // ignore fallback field
540 return {status, std::move(outputShapes), timing};
541 }
542
543 std::tuple<ErrorStatus, std::vector<OutputShape>, Timing, bool>
tryCompute(const Request & request,MeasureTiming measure,const std::vector<intptr_t> & memoryIds)544 ExecutionBurstController::tryCompute(const Request& request, MeasureTiming measure,
545 const std::vector<intptr_t>& memoryIds) {
546 NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");
547
548 std::lock_guard<std::mutex> guard(mMutex);
549
550 // send request packet
551 const std::vector<int32_t> slots = mMemoryCache->getSlots(request.pools, memoryIds);
552 const bool success = mRequestChannelSender->send(request, measure, slots);
553 if (!success) {
554 LOG(ERROR) << "Error sending FMQ packet";
555 // only use fallback execution path if the packet could not be sent
556 return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming, /*fallback=*/true};
557 }
558
559 // get result packet
560 const auto result = mResultChannelReceiver->getBlocking();
561 if (!result) {
562 LOG(ERROR) << "Error retrieving FMQ packet";
563 // only use fallback execution path if the packet could not be sent
564 return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming, /*fallback=*/false};
565 }
566
567 // unpack results and return (only use fallback execution path if the
568 // packet could not be sent)
569 auto [status, outputShapes, timing] = std::move(*result);
570 return {status, std::move(outputShapes), timing, /*fallback=*/false};
571 }
572
freeMemory(intptr_t key)573 void ExecutionBurstController::freeMemory(intptr_t key) {
574 std::lock_guard<std::mutex> guard(mMutex);
575
576 bool valid;
577 int32_t slot;
578 std::tie(valid, slot) = mMemoryCache->freeMemory(key);
579 if (valid) {
580 mBurstContext->freeMemory(slot).isOk();
581 }
582 }
583
584 } // namespace android::nn
585