1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "ExecutionBurstController"
18
19 #include "ExecutionBurstController.h"
20
21 #include <android-base/logging.h>
22
23 #include <algorithm>
24 #include <cstring>
25 #include <limits>
26 #include <memory>
27 #include <string>
28 #include <tuple>
29 #include <utility>
30 #include <vector>
31
32 #include "HalInterfaces.h"
33 #include "Tracing.h"
34 #include "Utils.h"
35
36 namespace android::nn {
37 namespace {
38
39 using namespace hal;
40
41 using V1_2::FmqRequestDatum;
42 using V1_2::FmqResultDatum;
43 using V1_2::IBurstCallback;
44 using V1_2::IBurstContext;
45 using FmqRequestDescriptor = hardware::MQDescriptorSync<FmqRequestDatum>;
46 using FmqResultDescriptor = hardware::MQDescriptorSync<FmqResultDatum>;
47
48 constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
49 std::numeric_limits<uint64_t>::max()};
50
51 class BurstContextDeathHandler : public hidl_death_recipient {
52 public:
53 using Callback = std::function<void()>;
54
BurstContextDeathHandler(const Callback & onDeathCallback)55 BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) {
56 CHECK(onDeathCallback != nullptr);
57 }
58
serviceDied(uint64_t,const wp<hidl::base::V1_0::IBase> &)59 void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
60 LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!";
61 mOnDeathCallback();
62 }
63
64 private:
65 const Callback mOnDeathCallback;
66 };
67
68 } // anonymous namespace
69
70 // serialize a request into a packet
serialize(const V1_0::Request & request,MeasureTiming measure,const std::vector<int32_t> & slots)71 std::vector<FmqRequestDatum> serialize(const V1_0::Request& request, MeasureTiming measure,
72 const std::vector<int32_t>& slots) {
73 // count how many elements need to be sent for a request
74 size_t count = 2 + request.inputs.size() + request.outputs.size() + request.pools.size();
75 for (const auto& input : request.inputs) {
76 count += input.dimensions.size();
77 }
78 for (const auto& output : request.outputs) {
79 count += output.dimensions.size();
80 }
81
82 // create buffer to temporarily store elements
83 std::vector<FmqRequestDatum> data;
84 data.reserve(count);
85
86 // package packetInfo
87 {
88 FmqRequestDatum datum;
89 datum.packetInformation(
90 {/*.packetSize=*/static_cast<uint32_t>(count),
91 /*.numberOfInputOperands=*/static_cast<uint32_t>(request.inputs.size()),
92 /*.numberOfOutputOperands=*/static_cast<uint32_t>(request.outputs.size()),
93 /*.numberOfPools=*/static_cast<uint32_t>(request.pools.size())});
94 data.push_back(datum);
95 }
96
97 // package input data
98 for (const auto& input : request.inputs) {
99 // package operand information
100 FmqRequestDatum datum;
101 datum.inputOperandInformation(
102 {/*.hasNoValue=*/input.hasNoValue,
103 /*.location=*/input.location,
104 /*.numberOfDimensions=*/static_cast<uint32_t>(input.dimensions.size())});
105 data.push_back(datum);
106
107 // package operand dimensions
108 for (uint32_t dimension : input.dimensions) {
109 FmqRequestDatum datum;
110 datum.inputOperandDimensionValue(dimension);
111 data.push_back(datum);
112 }
113 }
114
115 // package output data
116 for (const auto& output : request.outputs) {
117 // package operand information
118 FmqRequestDatum datum;
119 datum.outputOperandInformation(
120 {/*.hasNoValue=*/output.hasNoValue,
121 /*.location=*/output.location,
122 /*.numberOfDimensions=*/static_cast<uint32_t>(output.dimensions.size())});
123 data.push_back(datum);
124
125 // package operand dimensions
126 for (uint32_t dimension : output.dimensions) {
127 FmqRequestDatum datum;
128 datum.outputOperandDimensionValue(dimension);
129 data.push_back(datum);
130 }
131 }
132
133 // package pool identifier
134 for (int32_t slot : slots) {
135 FmqRequestDatum datum;
136 datum.poolIdentifier(slot);
137 data.push_back(datum);
138 }
139
140 // package measureTiming
141 {
142 FmqRequestDatum datum;
143 datum.measureTiming(measure);
144 data.push_back(datum);
145 }
146
147 // return packet
148 return data;
149 }
150
151 // deserialize a packet into the result
deserialize(const std::vector<FmqResultDatum> & data)152 std::optional<std::tuple<V1_0::ErrorStatus, std::vector<OutputShape>, Timing>> deserialize(
153 const std::vector<FmqResultDatum>& data) {
154 using discriminator = FmqResultDatum::hidl_discriminator;
155
156 std::vector<OutputShape> outputShapes;
157 size_t index = 0;
158
159 // validate packet information
160 if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
161 LOG(ERROR) << "FMQ Result packet ill-formed";
162 return std::nullopt;
163 }
164
165 // unpackage packet information
166 const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation();
167 index++;
168 const uint32_t packetSize = packetInfo.packetSize;
169 const V1_0::ErrorStatus errorStatus = packetInfo.errorStatus;
170 const uint32_t numberOfOperands = packetInfo.numberOfOperands;
171
172 // verify packet size
173 if (data.size() != packetSize) {
174 LOG(ERROR) << "FMQ Result packet ill-formed";
175 return std::nullopt;
176 }
177
178 // unpackage operands
179 for (size_t operand = 0; operand < numberOfOperands; ++operand) {
180 // validate operand information
181 if (data[index].getDiscriminator() != discriminator::operandInformation) {
182 LOG(ERROR) << "FMQ Result packet ill-formed";
183 return std::nullopt;
184 }
185
186 // unpackage operand information
187 const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation();
188 index++;
189 const bool isSufficient = operandInfo.isSufficient;
190 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
191
192 // unpackage operand dimensions
193 std::vector<uint32_t> dimensions;
194 dimensions.reserve(numberOfDimensions);
195 for (size_t i = 0; i < numberOfDimensions; ++i) {
196 // validate dimension
197 if (data[index].getDiscriminator() != discriminator::operandDimensionValue) {
198 LOG(ERROR) << "FMQ Result packet ill-formed";
199 return std::nullopt;
200 }
201
202 // unpackage dimension
203 const uint32_t dimension = data[index].operandDimensionValue();
204 index++;
205
206 // store result
207 dimensions.push_back(dimension);
208 }
209
210 // store result
211 outputShapes.push_back({/*.dimensions=*/dimensions, /*.isSufficient=*/isSufficient});
212 }
213
214 // validate execution timing
215 if (data[index].getDiscriminator() != discriminator::executionTiming) {
216 LOG(ERROR) << "FMQ Result packet ill-formed";
217 return std::nullopt;
218 }
219
220 // unpackage execution timing
221 const Timing timing = data[index].executionTiming();
222 index++;
223
224 // validate packet information
225 if (index != packetSize) {
226 LOG(ERROR) << "FMQ Result packet ill-formed";
227 return std::nullopt;
228 }
229
230 // return result
231 return std::make_tuple(errorStatus, std::move(outputShapes), timing);
232 }
233
legacyConvertResultCodeToErrorStatus(int resultCode)234 V1_0::ErrorStatus legacyConvertResultCodeToErrorStatus(int resultCode) {
235 return convertToV1_0(convertResultCodeToErrorStatus(resultCode));
236 }
237
238 std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*>
create(size_t channelLength,std::chrono::microseconds pollingTimeWindow)239 ResultChannelReceiver::create(size_t channelLength, std::chrono::microseconds pollingTimeWindow) {
240 std::unique_ptr<FmqResultChannel> fmqResultChannel =
241 std::make_unique<FmqResultChannel>(channelLength, /*confEventFlag=*/true);
242 if (!fmqResultChannel->isValid()) {
243 LOG(ERROR) << "Unable to create ResultChannelReceiver";
244 return {nullptr, nullptr};
245 }
246
247 const FmqResultDescriptor* descriptor = fmqResultChannel->getDesc();
248 return std::make_pair(
249 std::make_unique<ResultChannelReceiver>(std::move(fmqResultChannel), pollingTimeWindow),
250 descriptor);
251 }
252
ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,std::chrono::microseconds pollingTimeWindow)253 ResultChannelReceiver::ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,
254 std::chrono::microseconds pollingTimeWindow)
255 : mFmqResultChannel(std::move(fmqResultChannel)), kPollingTimeWindow(pollingTimeWindow) {}
256
257 std::optional<std::tuple<V1_0::ErrorStatus, std::vector<OutputShape>, Timing>>
getBlocking()258 ResultChannelReceiver::getBlocking() {
259 const auto packet = getPacketBlocking();
260 if (!packet) {
261 return std::nullopt;
262 }
263
264 return deserialize(*packet);
265 }
266
invalidate()267 void ResultChannelReceiver::invalidate() {
268 mValid = false;
269
270 // force unblock
271 // ExecutionBurstController waits on a result packet after sending a
272 // request. If the driver containing ExecutionBurstServer crashes, the
273 // controller may be waiting on the futex. This force unblock wakes up any
274 // thread waiting on the futex.
275 // TODO: look for a different/better way to signal/notify the futex to
276 // wake up any thread waiting on it
277 FmqResultDatum datum;
278 datum.packetInformation({/*.packetSize=*/0, /*.errorStatus=*/V1_0::ErrorStatus::GENERAL_FAILURE,
279 /*.numberOfOperands=*/0});
280 mFmqResultChannel->writeBlocking(&datum, 1);
281 }
282
getPacketBlocking()283 std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
284 using discriminator = FmqResultDatum::hidl_discriminator;
285
286 if (!mValid) {
287 return std::nullopt;
288 }
289
290 // First spend time polling if results are available in FMQ instead of
291 // waiting on the futex. Polling is more responsive (yielding lower
292 // latencies), but can take up more power, so only poll for a limited period
293 // of time.
294
295 auto& getCurrentTime = std::chrono::high_resolution_clock::now;
296 const auto timeToStopPolling = getCurrentTime() + kPollingTimeWindow;
297
298 while (getCurrentTime() < timeToStopPolling) {
299 // if class is being torn down, immediately return
300 if (!mValid.load(std::memory_order_relaxed)) {
301 return std::nullopt;
302 }
303
304 // Check if data is available. If it is, immediately retrieve it and
305 // return.
306 const size_t available = mFmqResultChannel->availableToRead();
307 if (available > 0) {
308 std::vector<FmqResultDatum> packet(available);
309 const bool success = mFmqResultChannel->read(packet.data(), available);
310 if (!success) {
311 LOG(ERROR) << "Error receiving packet";
312 return std::nullopt;
313 }
314 return std::make_optional(std::move(packet));
315 }
316 }
317
318 // If we get to this point, we either stopped polling because it was taking
319 // too long or polling was not allowed. Instead, perform a blocking call
320 // which uses a futex to save power.
321
322 // wait for result packet and read first element of result packet
323 FmqResultDatum datum;
324 bool success = mFmqResultChannel->readBlocking(&datum, 1);
325
326 // retrieve remaining elements
327 // NOTE: all of the data is already available at this point, so there's no
328 // need to do a blocking wait to wait for more data. This is known because
329 // in FMQ, all writes are published (made available) atomically. Currently,
330 // the producer always publishes the entire packet in one function call, so
331 // if the first element of the packet is available, the remaining elements
332 // are also available.
333 const size_t count = mFmqResultChannel->availableToRead();
334 std::vector<FmqResultDatum> packet(count + 1);
335 std::memcpy(&packet.front(), &datum, sizeof(datum));
336 success &= mFmqResultChannel->read(packet.data() + 1, count);
337
338 if (!mValid) {
339 return std::nullopt;
340 }
341
342 // ensure packet was successfully received
343 if (!success) {
344 LOG(ERROR) << "Error receiving packet";
345 return std::nullopt;
346 }
347
348 return std::make_optional(std::move(packet));
349 }
350
351 std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*>
create(size_t channelLength)352 RequestChannelSender::create(size_t channelLength) {
353 std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
354 std::make_unique<FmqRequestChannel>(channelLength, /*confEventFlag=*/true);
355 if (!fmqRequestChannel->isValid()) {
356 LOG(ERROR) << "Unable to create RequestChannelSender";
357 return {nullptr, nullptr};
358 }
359
360 const FmqRequestDescriptor* descriptor = fmqRequestChannel->getDesc();
361 return std::make_pair(std::make_unique<RequestChannelSender>(std::move(fmqRequestChannel)),
362 descriptor);
363 }
364
RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel)365 RequestChannelSender::RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel)
366 : mFmqRequestChannel(std::move(fmqRequestChannel)) {}
367
send(const V1_0::Request & request,MeasureTiming measure,const std::vector<int32_t> & slots)368 bool RequestChannelSender::send(const V1_0::Request& request, MeasureTiming measure,
369 const std::vector<int32_t>& slots) {
370 const std::vector<FmqRequestDatum> serialized = serialize(request, measure, slots);
371 return sendPacket(serialized);
372 }
373
sendPacket(const std::vector<FmqRequestDatum> & packet)374 bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
375 if (!mValid) {
376 return false;
377 }
378
379 if (packet.size() > mFmqRequestChannel->availableToWrite()) {
380 LOG(ERROR)
381 << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ";
382 return false;
383 }
384
385 // Always send the packet with "blocking" because this signals the futex and
386 // unblocks the consumer if it is waiting on the futex.
387 return mFmqRequestChannel->writeBlocking(packet.data(), packet.size());
388 }
389
// Marks this sender invalid; afterwards sendPacket() returns false without
// touching the FMQ. ExecutionBurstController::create calls this from the
// IBurstContext death handler.
void RequestChannelSender::invalidate() {
    mValid = false;
}
393
getMemories(const hidl_vec<int32_t> & slots,getMemories_cb cb)394 Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
395 const hidl_vec<int32_t>& slots, getMemories_cb cb) {
396 std::lock_guard<std::mutex> guard(mMutex);
397
398 // get all memories
399 hidl_vec<hidl_memory> memories(slots.size());
400 std::transform(slots.begin(), slots.end(), memories.begin(), [this](int32_t slot) {
401 return slot < mMemoryCache.size() ? mMemoryCache[slot] : hidl_memory{};
402 });
403
404 // ensure all memories are valid
405 if (!std::all_of(memories.begin(), memories.end(),
406 [](const hidl_memory& memory) { return memory.valid(); })) {
407 cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
408 return Void();
409 }
410
411 // return successful
412 cb(V1_0::ErrorStatus::NONE, std::move(memories));
413 return Void();
414 }
415
getSlots(const hidl_vec<hidl_memory> & memories,const std::vector<intptr_t> & keys)416 std::vector<int32_t> ExecutionBurstController::ExecutionBurstCallback::getSlots(
417 const hidl_vec<hidl_memory>& memories, const std::vector<intptr_t>& keys) {
418 std::lock_guard<std::mutex> guard(mMutex);
419
420 // retrieve (or bind) all slots corresponding to memories
421 std::vector<int32_t> slots;
422 slots.reserve(memories.size());
423 for (size_t i = 0; i < memories.size(); ++i) {
424 slots.push_back(getSlotLocked(memories[i], keys[i]));
425 }
426 return slots;
427 }
428
freeMemory(intptr_t key)429 std::pair<bool, int32_t> ExecutionBurstController::ExecutionBurstCallback::freeMemory(
430 intptr_t key) {
431 std::lock_guard<std::mutex> guard(mMutex);
432
433 auto iter = mMemoryIdToSlot.find(key);
434 if (iter == mMemoryIdToSlot.end()) {
435 return {false, 0};
436 }
437 const int32_t slot = iter->second;
438 mMemoryIdToSlot.erase(key);
439 mMemoryCache[slot] = {};
440 mFreeSlots.push(slot);
441 return {true, slot};
442 }
443
getSlotLocked(const hidl_memory & memory,intptr_t key)444 int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked(const hidl_memory& memory,
445 intptr_t key) {
446 auto iter = mMemoryIdToSlot.find(key);
447 if (iter == mMemoryIdToSlot.end()) {
448 const int32_t slot = allocateSlotLocked();
449 mMemoryIdToSlot[key] = slot;
450 mMemoryCache[slot] = memory;
451 return slot;
452 } else {
453 const int32_t slot = iter->second;
454 return slot;
455 }
456 }
457
allocateSlotLocked()458 int32_t ExecutionBurstController::ExecutionBurstCallback::allocateSlotLocked() {
459 constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max();
460
461 // if there is a free slot, use it
462 if (mFreeSlots.size() > 0) {
463 const int32_t slot = mFreeSlots.top();
464 mFreeSlots.pop();
465 return slot;
466 }
467
468 // otherwise use a slot for the first time
469 CHECK(mMemoryCache.size() < kMaxNumberOfSlots) << "Exceeded maximum number of slots!";
470 const int32_t slot = static_cast<int32_t>(mMemoryCache.size());
471 mMemoryCache.emplace_back();
472
473 return slot;
474 }
475
// Creates an ExecutionBurstController for `preparedModel`. The steps are:
//   1. Allocate the request and result FMQs. `pollingTimeWindow` is forwarded
//      to ResultChannelReceiver::create.
//   2. Ask the driver to configure an execution burst over those FMQs, passing
//      an ExecutionBurstCallback that the driver can use to fetch memories by
//      slot.
//   3. Register a death recipient on the returned IBurstContext that
//      invalidates both FMQ wrappers if the service dies.
// Returns nullptr (after logging) if any step fails.
std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
        const sp<V1_2::IPreparedModel>& preparedModel,
        std::chrono::microseconds pollingTimeWindow) {
    // check inputs
    if (preparedModel == nullptr) {
        LOG(ERROR) << "ExecutionBurstController::create passed a nullptr";
        return nullptr;
    }

    // create callback object
    sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();

    // create FMQ objects
    auto [requestChannelSenderTemp, requestChannelDescriptor] =
            RequestChannelSender::create(kExecutionBurstChannelLength);
    auto [resultChannelReceiverTemp, resultChannelDescriptor] =
            ResultChannelReceiver::create(kExecutionBurstChannelLength, pollingTimeWindow);
    // Promote to shared_ptr: both objects are shared between the controller and
    // the death-handler callback created below.
    std::shared_ptr<RequestChannelSender> requestChannelSender =
            std::move(requestChannelSenderTemp);
    std::shared_ptr<ResultChannelReceiver> resultChannelReceiver =
            std::move(resultChannelReceiverTemp);

    // check FMQ objects
    if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor ||
        !resultChannelDescriptor) {
        LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue";
        return nullptr;
    }

    // configure burst
    // NOTE(review): errorStatus is only assigned inside the callback; this relies
    // on the HIDL runtime invoking the callback whenever ret.isOk() -- confirm.
    V1_0::ErrorStatus errorStatus;
    sp<IBurstContext> burstContext;
    const Return<void> ret = preparedModel->configureExecutionBurst(
            callback, *requestChannelDescriptor, *resultChannelDescriptor,
            [&errorStatus, &burstContext](V1_0::ErrorStatus status,
                                          const sp<IBurstContext>& context) {
                errorStatus = status;
                burstContext = context;
            });

    // check burst
    if (!ret.isOk()) {
        LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with description "
                   << ret.description();
        return nullptr;
    }
    if (errorStatus != V1_0::ErrorStatus::NONE) {
        LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with status "
                   << toString(errorStatus);
        return nullptr;
    }
    if (burstContext == nullptr) {
        LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
        return nullptr;
    }

    // create death handler object; it invalidates both FMQ wrappers so that
    // pending and future sends/receives stop instead of waiting on a dead peer
    BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender,
                                                          resultChannelReceiver] {
        requestChannelSender->invalidate();
        resultChannelReceiver->invalidate();
    };
    const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback);

    // linkToDeath registers a callback that will be invoked on service death to
    // proactively handle service crashes. If the linkToDeath call fails,
    // asynchronous calls are susceptible to hangs if the service crashes before
    // providing the response.
    const Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0);
    if (!deathHandlerRet.isOk() || deathHandlerRet != true) {
        LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient "
                      "for the IBurstContext object.";
        return nullptr;
    }

    // make and return controller
    return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver,
                                                      burstContext, callback, deathHandler);
}
555
// Stores the objects produced by ExecutionBurstController::create: the FMQ
// sender and receiver, the driver's burst context, the memory-slot callback, and
// the death handler. `callback` is stored as mMemoryCache and is used for slot
// lookups in compute() and freeMemory().
ExecutionBurstController::ExecutionBurstController(
        const std::shared_ptr<RequestChannelSender>& requestChannelSender,
        const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
        const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback,
        const sp<hidl_death_recipient>& deathHandler)
    : mRequestChannelSender(requestChannelSender),
      mResultChannelReceiver(resultChannelReceiver),
      mBurstContext(burstContext),
      mMemoryCache(callback),
      mDeathHandler(deathHandler) {}
566
~ExecutionBurstController()567 ExecutionBurstController::~ExecutionBurstController() {
568 // It is safe to ignore any errors resulting from this unlinkToDeath call
569 // because the ExecutionBurstController object is already being destroyed
570 // and its underlying IBurstContext object is no longer being used by the NN
571 // runtime.
572 if (mDeathHandler) {
573 mBurstContext->unlinkToDeath(mDeathHandler).isOk();
574 }
575 }
576
getExecutionResult(V1_0::ErrorStatus status,std::vector<OutputShape> outputShapes,Timing timing,bool fallback)577 static std::tuple<int, std::vector<OutputShape>, Timing, bool> getExecutionResult(
578 V1_0::ErrorStatus status, std::vector<OutputShape> outputShapes, Timing timing,
579 bool fallback) {
580 auto [n, checkedOutputShapes, checkedTiming] =
581 getExecutionResult(convertToV1_3(status), std::move(outputShapes), timing);
582 return {n, std::move(checkedOutputShapes), checkedTiming, fallback};
583 }
584
compute(const V1_0::Request & request,MeasureTiming measure,const std::vector<intptr_t> & memoryIds)585 std::tuple<int, std::vector<OutputShape>, Timing, bool> ExecutionBurstController::compute(
586 const V1_0::Request& request, MeasureTiming measure,
587 const std::vector<intptr_t>& memoryIds) {
588 // This is the first point when we know an execution is occurring, so begin
589 // to collect systraces. Note that the first point we can begin collecting
590 // systraces in ExecutionBurstServer is when the RequestChannelReceiver
591 // realizes there is data in the FMQ, so ExecutionBurstServer collects
592 // systraces at different points in the code.
593 NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");
594
595 std::lock_guard<std::mutex> guard(mMutex);
596
597 // send request packet
598 const std::vector<int32_t> slots = mMemoryCache->getSlots(request.pools, memoryIds);
599 const bool success = mRequestChannelSender->send(request, measure, slots);
600 if (!success) {
601 LOG(ERROR) << "Error sending FMQ packet";
602 // only use fallback execution path if the packet could not be sent
603 return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming,
604 /*fallback=*/true);
605 }
606
607 // get result packet
608 const auto result = mResultChannelReceiver->getBlocking();
609 if (!result) {
610 LOG(ERROR) << "Error retrieving FMQ packet";
611 // only use fallback execution path if the packet could not be sent
612 return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming,
613 /*fallback=*/false);
614 }
615
616 // unpack results and return (only use fallback execution path if the
617 // packet could not be sent)
618 auto [status, outputShapes, timing] = std::move(*result);
619 return getExecutionResult(status, std::move(outputShapes), timing, /*fallback=*/false);
620 }
621
freeMemory(intptr_t key)622 void ExecutionBurstController::freeMemory(intptr_t key) {
623 std::lock_guard<std::mutex> guard(mMutex);
624
625 bool valid;
626 int32_t slot;
627 std::tie(valid, slot) = mMemoryCache->freeMemory(key);
628 if (valid) {
629 mBurstContext->freeMemory(slot).isOk();
630 }
631 }
632
633 } // namespace android::nn
634