1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "ExecutionBurstController"
18 
19 #include "ExecutionBurstController.h"
20 
21 #include <android-base/logging.h>
22 
23 #include <algorithm>
24 #include <cstring>
25 #include <limits>
26 #include <memory>
27 #include <string>
28 #include <tuple>
29 #include <utility>
30 #include <vector>
31 
32 #include "HalInterfaces.h"
33 #include "Tracing.h"
34 #include "Utils.h"
35 
36 namespace android::nn {
37 namespace {
38 
39 using namespace hal;
40 
41 using V1_2::FmqRequestDatum;
42 using V1_2::FmqResultDatum;
43 using V1_2::IBurstCallback;
44 using V1_2::IBurstContext;
45 using FmqRequestDescriptor = hardware::MQDescriptorSync<FmqRequestDatum>;
46 using FmqResultDescriptor = hardware::MQDescriptorSync<FmqResultDatum>;
47 
48 constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
49                               std::numeric_limits<uint64_t>::max()};
50 
51 class BurstContextDeathHandler : public hidl_death_recipient {
52    public:
53     using Callback = std::function<void()>;
54 
BurstContextDeathHandler(const Callback & onDeathCallback)55     BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) {
56         CHECK(onDeathCallback != nullptr);
57     }
58 
serviceDied(uint64_t,const wp<hidl::base::V1_0::IBase> &)59     void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
60         LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!";
61         mOnDeathCallback();
62     }
63 
64    private:
65     const Callback mOnDeathCallback;
66 };
67 
68 }  // anonymous namespace
69 
70 // serialize a request into a packet
serialize(const V1_0::Request & request,MeasureTiming measure,const std::vector<int32_t> & slots)71 std::vector<FmqRequestDatum> serialize(const V1_0::Request& request, MeasureTiming measure,
72                                        const std::vector<int32_t>& slots) {
73     // count how many elements need to be sent for a request
74     size_t count = 2 + request.inputs.size() + request.outputs.size() + request.pools.size();
75     for (const auto& input : request.inputs) {
76         count += input.dimensions.size();
77     }
78     for (const auto& output : request.outputs) {
79         count += output.dimensions.size();
80     }
81 
82     // create buffer to temporarily store elements
83     std::vector<FmqRequestDatum> data;
84     data.reserve(count);
85 
86     // package packetInfo
87     {
88         FmqRequestDatum datum;
89         datum.packetInformation(
90                 {/*.packetSize=*/static_cast<uint32_t>(count),
91                  /*.numberOfInputOperands=*/static_cast<uint32_t>(request.inputs.size()),
92                  /*.numberOfOutputOperands=*/static_cast<uint32_t>(request.outputs.size()),
93                  /*.numberOfPools=*/static_cast<uint32_t>(request.pools.size())});
94         data.push_back(datum);
95     }
96 
97     // package input data
98     for (const auto& input : request.inputs) {
99         // package operand information
100         FmqRequestDatum datum;
101         datum.inputOperandInformation(
102                 {/*.hasNoValue=*/input.hasNoValue,
103                  /*.location=*/input.location,
104                  /*.numberOfDimensions=*/static_cast<uint32_t>(input.dimensions.size())});
105         data.push_back(datum);
106 
107         // package operand dimensions
108         for (uint32_t dimension : input.dimensions) {
109             FmqRequestDatum datum;
110             datum.inputOperandDimensionValue(dimension);
111             data.push_back(datum);
112         }
113     }
114 
115     // package output data
116     for (const auto& output : request.outputs) {
117         // package operand information
118         FmqRequestDatum datum;
119         datum.outputOperandInformation(
120                 {/*.hasNoValue=*/output.hasNoValue,
121                  /*.location=*/output.location,
122                  /*.numberOfDimensions=*/static_cast<uint32_t>(output.dimensions.size())});
123         data.push_back(datum);
124 
125         // package operand dimensions
126         for (uint32_t dimension : output.dimensions) {
127             FmqRequestDatum datum;
128             datum.outputOperandDimensionValue(dimension);
129             data.push_back(datum);
130         }
131     }
132 
133     // package pool identifier
134     for (int32_t slot : slots) {
135         FmqRequestDatum datum;
136         datum.poolIdentifier(slot);
137         data.push_back(datum);
138     }
139 
140     // package measureTiming
141     {
142         FmqRequestDatum datum;
143         datum.measureTiming(measure);
144         data.push_back(datum);
145     }
146 
147     // return packet
148     return data;
149 }
150 
151 // deserialize a packet into the result
deserialize(const std::vector<FmqResultDatum> & data)152 std::optional<std::tuple<V1_0::ErrorStatus, std::vector<OutputShape>, Timing>> deserialize(
153         const std::vector<FmqResultDatum>& data) {
154     using discriminator = FmqResultDatum::hidl_discriminator;
155 
156     std::vector<OutputShape> outputShapes;
157     size_t index = 0;
158 
159     // validate packet information
160     if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
161         LOG(ERROR) << "FMQ Result packet ill-formed";
162         return std::nullopt;
163     }
164 
165     // unpackage packet information
166     const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation();
167     index++;
168     const uint32_t packetSize = packetInfo.packetSize;
169     const V1_0::ErrorStatus errorStatus = packetInfo.errorStatus;
170     const uint32_t numberOfOperands = packetInfo.numberOfOperands;
171 
172     // verify packet size
173     if (data.size() != packetSize) {
174         LOG(ERROR) << "FMQ Result packet ill-formed";
175         return std::nullopt;
176     }
177 
178     // unpackage operands
179     for (size_t operand = 0; operand < numberOfOperands; ++operand) {
180         // validate operand information
181         if (data[index].getDiscriminator() != discriminator::operandInformation) {
182             LOG(ERROR) << "FMQ Result packet ill-formed";
183             return std::nullopt;
184         }
185 
186         // unpackage operand information
187         const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation();
188         index++;
189         const bool isSufficient = operandInfo.isSufficient;
190         const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
191 
192         // unpackage operand dimensions
193         std::vector<uint32_t> dimensions;
194         dimensions.reserve(numberOfDimensions);
195         for (size_t i = 0; i < numberOfDimensions; ++i) {
196             // validate dimension
197             if (data[index].getDiscriminator() != discriminator::operandDimensionValue) {
198                 LOG(ERROR) << "FMQ Result packet ill-formed";
199                 return std::nullopt;
200             }
201 
202             // unpackage dimension
203             const uint32_t dimension = data[index].operandDimensionValue();
204             index++;
205 
206             // store result
207             dimensions.push_back(dimension);
208         }
209 
210         // store result
211         outputShapes.push_back({/*.dimensions=*/dimensions, /*.isSufficient=*/isSufficient});
212     }
213 
214     // validate execution timing
215     if (data[index].getDiscriminator() != discriminator::executionTiming) {
216         LOG(ERROR) << "FMQ Result packet ill-formed";
217         return std::nullopt;
218     }
219 
220     // unpackage execution timing
221     const Timing timing = data[index].executionTiming();
222     index++;
223 
224     // validate packet information
225     if (index != packetSize) {
226         LOG(ERROR) << "FMQ Result packet ill-formed";
227         return std::nullopt;
228     }
229 
230     // return result
231     return std::make_tuple(errorStatus, std::move(outputShapes), timing);
232 }
233 
legacyConvertResultCodeToErrorStatus(int resultCode)234 V1_0::ErrorStatus legacyConvertResultCodeToErrorStatus(int resultCode) {
235     return convertToV1_0(convertResultCodeToErrorStatus(resultCode));
236 }
237 
238 std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*>
create(size_t channelLength,std::chrono::microseconds pollingTimeWindow)239 ResultChannelReceiver::create(size_t channelLength, std::chrono::microseconds pollingTimeWindow) {
240     std::unique_ptr<FmqResultChannel> fmqResultChannel =
241             std::make_unique<FmqResultChannel>(channelLength, /*confEventFlag=*/true);
242     if (!fmqResultChannel->isValid()) {
243         LOG(ERROR) << "Unable to create ResultChannelReceiver";
244         return {nullptr, nullptr};
245     }
246 
247     const FmqResultDescriptor* descriptor = fmqResultChannel->getDesc();
248     return std::make_pair(
249             std::make_unique<ResultChannelReceiver>(std::move(fmqResultChannel), pollingTimeWindow),
250             descriptor);
251 }
252 
ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,std::chrono::microseconds pollingTimeWindow)253 ResultChannelReceiver::ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,
254                                              std::chrono::microseconds pollingTimeWindow)
255     : mFmqResultChannel(std::move(fmqResultChannel)), kPollingTimeWindow(pollingTimeWindow) {}
256 
257 std::optional<std::tuple<V1_0::ErrorStatus, std::vector<OutputShape>, Timing>>
getBlocking()258 ResultChannelReceiver::getBlocking() {
259     const auto packet = getPacketBlocking();
260     if (!packet) {
261         return std::nullopt;
262     }
263 
264     return deserialize(*packet);
265 }
266 
invalidate()267 void ResultChannelReceiver::invalidate() {
268     mValid = false;
269 
270     // force unblock
271     // ExecutionBurstController waits on a result packet after sending a
272     // request. If the driver containing ExecutionBurstServer crashes, the
273     // controller may be waiting on the futex. This force unblock wakes up any
274     // thread waiting on the futex.
275     // TODO: look for a different/better way to signal/notify the futex to
276     // wake up any thread waiting on it
277     FmqResultDatum datum;
278     datum.packetInformation({/*.packetSize=*/0, /*.errorStatus=*/V1_0::ErrorStatus::GENERAL_FAILURE,
279                              /*.numberOfOperands=*/0});
280     mFmqResultChannel->writeBlocking(&datum, 1);
281 }
282 
getPacketBlocking()283 std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
284     using discriminator = FmqResultDatum::hidl_discriminator;
285 
286     if (!mValid) {
287         return std::nullopt;
288     }
289 
290     // First spend time polling if results are available in FMQ instead of
291     // waiting on the futex. Polling is more responsive (yielding lower
292     // latencies), but can take up more power, so only poll for a limited period
293     // of time.
294 
295     auto& getCurrentTime = std::chrono::high_resolution_clock::now;
296     const auto timeToStopPolling = getCurrentTime() + kPollingTimeWindow;
297 
298     while (getCurrentTime() < timeToStopPolling) {
299         // if class is being torn down, immediately return
300         if (!mValid.load(std::memory_order_relaxed)) {
301             return std::nullopt;
302         }
303 
304         // Check if data is available. If it is, immediately retrieve it and
305         // return.
306         const size_t available = mFmqResultChannel->availableToRead();
307         if (available > 0) {
308             std::vector<FmqResultDatum> packet(available);
309             const bool success = mFmqResultChannel->read(packet.data(), available);
310             if (!success) {
311                 LOG(ERROR) << "Error receiving packet";
312                 return std::nullopt;
313             }
314             return std::make_optional(std::move(packet));
315         }
316     }
317 
318     // If we get to this point, we either stopped polling because it was taking
319     // too long or polling was not allowed. Instead, perform a blocking call
320     // which uses a futex to save power.
321 
322     // wait for result packet and read first element of result packet
323     FmqResultDatum datum;
324     bool success = mFmqResultChannel->readBlocking(&datum, 1);
325 
326     // retrieve remaining elements
327     // NOTE: all of the data is already available at this point, so there's no
328     // need to do a blocking wait to wait for more data. This is known because
329     // in FMQ, all writes are published (made available) atomically. Currently,
330     // the producer always publishes the entire packet in one function call, so
331     // if the first element of the packet is available, the remaining elements
332     // are also available.
333     const size_t count = mFmqResultChannel->availableToRead();
334     std::vector<FmqResultDatum> packet(count + 1);
335     std::memcpy(&packet.front(), &datum, sizeof(datum));
336     success &= mFmqResultChannel->read(packet.data() + 1, count);
337 
338     if (!mValid) {
339         return std::nullopt;
340     }
341 
342     // ensure packet was successfully received
343     if (!success) {
344         LOG(ERROR) << "Error receiving packet";
345         return std::nullopt;
346     }
347 
348     return std::make_optional(std::move(packet));
349 }
350 
351 std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*>
create(size_t channelLength)352 RequestChannelSender::create(size_t channelLength) {
353     std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
354             std::make_unique<FmqRequestChannel>(channelLength, /*confEventFlag=*/true);
355     if (!fmqRequestChannel->isValid()) {
356         LOG(ERROR) << "Unable to create RequestChannelSender";
357         return {nullptr, nullptr};
358     }
359 
360     const FmqRequestDescriptor* descriptor = fmqRequestChannel->getDesc();
361     return std::make_pair(std::make_unique<RequestChannelSender>(std::move(fmqRequestChannel)),
362                           descriptor);
363 }
364 
RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel)365 RequestChannelSender::RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel)
366     : mFmqRequestChannel(std::move(fmqRequestChannel)) {}
367 
send(const V1_0::Request & request,MeasureTiming measure,const std::vector<int32_t> & slots)368 bool RequestChannelSender::send(const V1_0::Request& request, MeasureTiming measure,
369                                 const std::vector<int32_t>& slots) {
370     const std::vector<FmqRequestDatum> serialized = serialize(request, measure, slots);
371     return sendPacket(serialized);
372 }
373 
sendPacket(const std::vector<FmqRequestDatum> & packet)374 bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
375     if (!mValid) {
376         return false;
377     }
378 
379     if (packet.size() > mFmqRequestChannel->availableToWrite()) {
380         LOG(ERROR)
381                 << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ";
382         return false;
383     }
384 
385     // Always send the packet with "blocking" because this signals the futex and
386     // unblocks the consumer if it is waiting on the futex.
387     return mFmqRequestChannel->writeBlocking(packet.data(), packet.size());
388 }
389 
invalidate()390 void RequestChannelSender::invalidate() {
391     mValid = false;
392 }
393 
getMemories(const hidl_vec<int32_t> & slots,getMemories_cb cb)394 Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
395         const hidl_vec<int32_t>& slots, getMemories_cb cb) {
396     std::lock_guard<std::mutex> guard(mMutex);
397 
398     // get all memories
399     hidl_vec<hidl_memory> memories(slots.size());
400     std::transform(slots.begin(), slots.end(), memories.begin(), [this](int32_t slot) {
401         return slot < mMemoryCache.size() ? mMemoryCache[slot] : hidl_memory{};
402     });
403 
404     // ensure all memories are valid
405     if (!std::all_of(memories.begin(), memories.end(),
406                      [](const hidl_memory& memory) { return memory.valid(); })) {
407         cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
408         return Void();
409     }
410 
411     // return successful
412     cb(V1_0::ErrorStatus::NONE, std::move(memories));
413     return Void();
414 }
415 
getSlots(const hidl_vec<hidl_memory> & memories,const std::vector<intptr_t> & keys)416 std::vector<int32_t> ExecutionBurstController::ExecutionBurstCallback::getSlots(
417         const hidl_vec<hidl_memory>& memories, const std::vector<intptr_t>& keys) {
418     std::lock_guard<std::mutex> guard(mMutex);
419 
420     // retrieve (or bind) all slots corresponding to memories
421     std::vector<int32_t> slots;
422     slots.reserve(memories.size());
423     for (size_t i = 0; i < memories.size(); ++i) {
424         slots.push_back(getSlotLocked(memories[i], keys[i]));
425     }
426     return slots;
427 }
428 
freeMemory(intptr_t key)429 std::pair<bool, int32_t> ExecutionBurstController::ExecutionBurstCallback::freeMemory(
430         intptr_t key) {
431     std::lock_guard<std::mutex> guard(mMutex);
432 
433     auto iter = mMemoryIdToSlot.find(key);
434     if (iter == mMemoryIdToSlot.end()) {
435         return {false, 0};
436     }
437     const int32_t slot = iter->second;
438     mMemoryIdToSlot.erase(key);
439     mMemoryCache[slot] = {};
440     mFreeSlots.push(slot);
441     return {true, slot};
442 }
443 
getSlotLocked(const hidl_memory & memory,intptr_t key)444 int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked(const hidl_memory& memory,
445                                                                         intptr_t key) {
446     auto iter = mMemoryIdToSlot.find(key);
447     if (iter == mMemoryIdToSlot.end()) {
448         const int32_t slot = allocateSlotLocked();
449         mMemoryIdToSlot[key] = slot;
450         mMemoryCache[slot] = memory;
451         return slot;
452     } else {
453         const int32_t slot = iter->second;
454         return slot;
455     }
456 }
457 
allocateSlotLocked()458 int32_t ExecutionBurstController::ExecutionBurstCallback::allocateSlotLocked() {
459     constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max();
460 
461     // if there is a free slot, use it
462     if (mFreeSlots.size() > 0) {
463         const int32_t slot = mFreeSlots.top();
464         mFreeSlots.pop();
465         return slot;
466     }
467 
468     // otherwise use a slot for the first time
469     CHECK(mMemoryCache.size() < kMaxNumberOfSlots) << "Exceeded maximum number of slots!";
470     const int32_t slot = static_cast<int32_t>(mMemoryCache.size());
471     mMemoryCache.emplace_back();
472 
473     return slot;
474 }
475 
create(const sp<V1_2::IPreparedModel> & preparedModel,std::chrono::microseconds pollingTimeWindow)476 std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
477         const sp<V1_2::IPreparedModel>& preparedModel,
478         std::chrono::microseconds pollingTimeWindow) {
479     // check inputs
480     if (preparedModel == nullptr) {
481         LOG(ERROR) << "ExecutionBurstController::create passed a nullptr";
482         return nullptr;
483     }
484 
485     // create callback object
486     sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
487 
488     // create FMQ objects
489     auto [requestChannelSenderTemp, requestChannelDescriptor] =
490             RequestChannelSender::create(kExecutionBurstChannelLength);
491     auto [resultChannelReceiverTemp, resultChannelDescriptor] =
492             ResultChannelReceiver::create(kExecutionBurstChannelLength, pollingTimeWindow);
493     std::shared_ptr<RequestChannelSender> requestChannelSender =
494             std::move(requestChannelSenderTemp);
495     std::shared_ptr<ResultChannelReceiver> resultChannelReceiver =
496             std::move(resultChannelReceiverTemp);
497 
498     // check FMQ objects
499     if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor ||
500         !resultChannelDescriptor) {
501         LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue";
502         return nullptr;
503     }
504 
505     // configure burst
506     V1_0::ErrorStatus errorStatus;
507     sp<IBurstContext> burstContext;
508     const Return<void> ret = preparedModel->configureExecutionBurst(
509             callback, *requestChannelDescriptor, *resultChannelDescriptor,
510             [&errorStatus, &burstContext](V1_0::ErrorStatus status,
511                                           const sp<IBurstContext>& context) {
512                 errorStatus = status;
513                 burstContext = context;
514             });
515 
516     // check burst
517     if (!ret.isOk()) {
518         LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with description "
519                    << ret.description();
520         return nullptr;
521     }
522     if (errorStatus != V1_0::ErrorStatus::NONE) {
523         LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with status "
524                    << toString(errorStatus);
525         return nullptr;
526     }
527     if (burstContext == nullptr) {
528         LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
529         return nullptr;
530     }
531 
532     // create death handler object
533     BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender,
534                                                           resultChannelReceiver] {
535         requestChannelSender->invalidate();
536         resultChannelReceiver->invalidate();
537     };
538     const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback);
539 
540     // linkToDeath registers a callback that will be invoked on service death to
541     // proactively handle service crashes. If the linkToDeath call fails,
542     // asynchronous calls are susceptible to hangs if the service crashes before
543     // providing the response.
544     const Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0);
545     if (!deathHandlerRet.isOk() || deathHandlerRet != true) {
546         LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient "
547                       "for the IBurstContext object.";
548         return nullptr;
549     }
550 
551     // make and return controller
552     return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver,
553                                                       burstContext, callback, deathHandler);
554 }
555 
ExecutionBurstController(const std::shared_ptr<RequestChannelSender> & requestChannelSender,const std::shared_ptr<ResultChannelReceiver> & resultChannelReceiver,const sp<IBurstContext> & burstContext,const sp<ExecutionBurstCallback> & callback,const sp<hidl_death_recipient> & deathHandler)556 ExecutionBurstController::ExecutionBurstController(
557         const std::shared_ptr<RequestChannelSender>& requestChannelSender,
558         const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
559         const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback,
560         const sp<hidl_death_recipient>& deathHandler)
561     : mRequestChannelSender(requestChannelSender),
562       mResultChannelReceiver(resultChannelReceiver),
563       mBurstContext(burstContext),
564       mMemoryCache(callback),
565       mDeathHandler(deathHandler) {}
566 
~ExecutionBurstController()567 ExecutionBurstController::~ExecutionBurstController() {
568     // It is safe to ignore any errors resulting from this unlinkToDeath call
569     // because the ExecutionBurstController object is already being destroyed
570     // and its underlying IBurstContext object is no longer being used by the NN
571     // runtime.
572     if (mDeathHandler) {
573         mBurstContext->unlinkToDeath(mDeathHandler).isOk();
574     }
575 }
576 
getExecutionResult(V1_0::ErrorStatus status,std::vector<OutputShape> outputShapes,Timing timing,bool fallback)577 static std::tuple<int, std::vector<OutputShape>, Timing, bool> getExecutionResult(
578         V1_0::ErrorStatus status, std::vector<OutputShape> outputShapes, Timing timing,
579         bool fallback) {
580     auto [n, checkedOutputShapes, checkedTiming] =
581             getExecutionResult(convertToV1_3(status), std::move(outputShapes), timing);
582     return {n, std::move(checkedOutputShapes), checkedTiming, fallback};
583 }
584 
compute(const V1_0::Request & request,MeasureTiming measure,const std::vector<intptr_t> & memoryIds)585 std::tuple<int, std::vector<OutputShape>, Timing, bool> ExecutionBurstController::compute(
586         const V1_0::Request& request, MeasureTiming measure,
587         const std::vector<intptr_t>& memoryIds) {
588     // This is the first point when we know an execution is occurring, so begin
589     // to collect systraces. Note that the first point we can begin collecting
590     // systraces in ExecutionBurstServer is when the RequestChannelReceiver
591     // realizes there is data in the FMQ, so ExecutionBurstServer collects
592     // systraces at different points in the code.
593     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");
594 
595     std::lock_guard<std::mutex> guard(mMutex);
596 
597     // send request packet
598     const std::vector<int32_t> slots = mMemoryCache->getSlots(request.pools, memoryIds);
599     const bool success = mRequestChannelSender->send(request, measure, slots);
600     if (!success) {
601         LOG(ERROR) << "Error sending FMQ packet";
602         // only use fallback execution path if the packet could not be sent
603         return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming,
604                                   /*fallback=*/true);
605     }
606 
607     // get result packet
608     const auto result = mResultChannelReceiver->getBlocking();
609     if (!result) {
610         LOG(ERROR) << "Error retrieving FMQ packet";
611         // only use fallback execution path if the packet could not be sent
612         return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming,
613                                   /*fallback=*/false);
614     }
615 
616     // unpack results and return (only use fallback execution path if the
617     // packet could not be sent)
618     auto [status, outputShapes, timing] = std::move(*result);
619     return getExecutionResult(status, std::move(outputShapes), timing, /*fallback=*/false);
620 }
621 
freeMemory(intptr_t key)622 void ExecutionBurstController::freeMemory(intptr_t key) {
623     std::lock_guard<std::mutex> guard(mMutex);
624 
625     bool valid;
626     int32_t slot;
627     std::tie(valid, slot) = mMemoryCache->freeMemory(key);
628     if (valid) {
629         mBurstContext->freeMemory(slot).isOk();
630     }
631 }
632 
633 }  // namespace android::nn
634