1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "ExecutionBurstUtils"
18
19 #include "ExecutionBurstUtils.h"
20
21 #include <android-base/logging.h>
22 #include <android-base/properties.h>
23 #include <android/hardware/neuralnetworks/1.0/types.h>
24 #include <android/hardware/neuralnetworks/1.1/types.h>
25 #include <android/hardware/neuralnetworks/1.2/types.h>
26 #include <fmq/MessageQueue.h>
27 #include <hidl/MQDescriptor.h>
28 #include <nnapi/Result.h>
29 #include <nnapi/Types.h>
30 #include <nnapi/hal/ProtectCallback.h>
31
#include <atomic>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <string>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>
39
40 namespace android::hardware::neuralnetworks::V1_2::utils {
41 namespace {
42
43 constexpr V1_2::Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
44 std::numeric_limits<uint64_t>::max()};
45
getPollingTimeWindow(const std::string & property)46 std::chrono::microseconds getPollingTimeWindow(const std::string& property) {
47 constexpr int32_t kDefaultPollingTimeWindow = 0;
48 #ifdef NN_DEBUGGABLE
49 constexpr int32_t kMinPollingTimeWindow = 0;
50 const int32_t selectedPollingTimeWindow =
51 base::GetIntProperty(property, kDefaultPollingTimeWindow, kMinPollingTimeWindow);
52 return std::chrono::microseconds(selectedPollingTimeWindow);
53 #else
54 (void)property;
55 return std::chrono::microseconds(kDefaultPollingTimeWindow);
56 #endif // NN_DEBUGGABLE
57 }
58
59 } // namespace
60
getBurstControllerPollingTimeWindow()61 std::chrono::microseconds getBurstControllerPollingTimeWindow() {
62 return getPollingTimeWindow("debug.nn.burst-controller-polling-window");
63 }
64
getBurstServerPollingTimeWindow()65 std::chrono::microseconds getBurstServerPollingTimeWindow() {
66 return getPollingTimeWindow("debug.nn.burst-server-polling-window");
67 }
68
69 // serialize a request into a packet
serialize(const V1_0::Request & request,V1_2::MeasureTiming measure,const std::vector<int32_t> & slots)70 std::vector<FmqRequestDatum> serialize(const V1_0::Request& request, V1_2::MeasureTiming measure,
71 const std::vector<int32_t>& slots) {
72 // count how many elements need to be sent for a request
73 size_t count = 2 + request.inputs.size() + request.outputs.size() + slots.size();
74 for (const auto& input : request.inputs) {
75 count += input.dimensions.size();
76 }
77 for (const auto& output : request.outputs) {
78 count += output.dimensions.size();
79 }
80 CHECK_LE(count, std::numeric_limits<uint32_t>::max());
81
82 // create buffer to temporarily store elements
83 std::vector<FmqRequestDatum> data;
84 data.reserve(count);
85
86 // package packetInfo
87 data.emplace_back();
88 data.back().packetInformation(
89 {.packetSize = static_cast<uint32_t>(count),
90 .numberOfInputOperands = static_cast<uint32_t>(request.inputs.size()),
91 .numberOfOutputOperands = static_cast<uint32_t>(request.outputs.size()),
92 .numberOfPools = static_cast<uint32_t>(slots.size())});
93
94 // package input data
95 for (const auto& input : request.inputs) {
96 // package operand information
97 data.emplace_back();
98 data.back().inputOperandInformation(
99 {.hasNoValue = input.hasNoValue,
100 .location = input.location,
101 .numberOfDimensions = static_cast<uint32_t>(input.dimensions.size())});
102
103 // package operand dimensions
104 for (uint32_t dimension : input.dimensions) {
105 data.emplace_back();
106 data.back().inputOperandDimensionValue(dimension);
107 }
108 }
109
110 // package output data
111 for (const auto& output : request.outputs) {
112 // package operand information
113 data.emplace_back();
114 data.back().outputOperandInformation(
115 {.hasNoValue = output.hasNoValue,
116 .location = output.location,
117 .numberOfDimensions = static_cast<uint32_t>(output.dimensions.size())});
118
119 // package operand dimensions
120 for (uint32_t dimension : output.dimensions) {
121 data.emplace_back();
122 data.back().outputOperandDimensionValue(dimension);
123 }
124 }
125
126 // package pool identifier
127 for (int32_t slot : slots) {
128 data.emplace_back();
129 data.back().poolIdentifier(slot);
130 }
131
132 // package measureTiming
133 data.emplace_back();
134 data.back().measureTiming(measure);
135
136 CHECK_EQ(data.size(), count);
137
138 // return packet
139 return data;
140 }
141
142 // serialize result
serialize(V1_0::ErrorStatus errorStatus,const std::vector<V1_2::OutputShape> & outputShapes,V1_2::Timing timing)143 std::vector<FmqResultDatum> serialize(V1_0::ErrorStatus errorStatus,
144 const std::vector<V1_2::OutputShape>& outputShapes,
145 V1_2::Timing timing) {
146 // count how many elements need to be sent for a request
147 size_t count = 2 + outputShapes.size();
148 for (const auto& outputShape : outputShapes) {
149 count += outputShape.dimensions.size();
150 }
151
152 // create buffer to temporarily store elements
153 std::vector<FmqResultDatum> data;
154 data.reserve(count);
155
156 // package packetInfo
157 data.emplace_back();
158 data.back().packetInformation({.packetSize = static_cast<uint32_t>(count),
159 .errorStatus = errorStatus,
160 .numberOfOperands = static_cast<uint32_t>(outputShapes.size())});
161
162 // package output shape data
163 for (const auto& operand : outputShapes) {
164 // package operand information
165 data.emplace_back();
166 data.back().operandInformation(
167 {.isSufficient = operand.isSufficient,
168 .numberOfDimensions = static_cast<uint32_t>(operand.dimensions.size())});
169
170 // package operand dimensions
171 for (uint32_t dimension : operand.dimensions) {
172 data.emplace_back();
173 data.back().operandDimensionValue(dimension);
174 }
175 }
176
177 // package executionTiming
178 data.emplace_back();
179 data.back().executionTiming(timing);
180
181 CHECK_EQ(data.size(), count);
182
183 // return result
184 return data;
185 }
186
187 // deserialize request
deserialize(const std::vector<FmqRequestDatum> & data)188 nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>> deserialize(
189 const std::vector<FmqRequestDatum>& data) {
190 using discriminator = FmqRequestDatum::hidl_discriminator;
191
192 size_t index = 0;
193
194 // validate packet information
195 if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
196 return NN_ERROR() << "FMQ Request packet ill-formed";
197 }
198
199 // unpackage packet information
200 const FmqRequestDatum::PacketInformation& packetInfo = data[index].packetInformation();
201 index++;
202 const uint32_t packetSize = packetInfo.packetSize;
203 const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands;
204 const uint32_t numberOfOutputOperands = packetInfo.numberOfOutputOperands;
205 const uint32_t numberOfPools = packetInfo.numberOfPools;
206
207 // verify packet size
208 if (data.size() != packetSize) {
209 return NN_ERROR() << "FMQ Request packet ill-formed";
210 }
211
212 // unpackage input operands
213 std::vector<V1_0::RequestArgument> inputs;
214 inputs.reserve(numberOfInputOperands);
215 for (size_t operand = 0; operand < numberOfInputOperands; ++operand) {
216 // validate input operand information
217 if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
218 return NN_ERROR() << "FMQ Request packet ill-formed";
219 }
220
221 // unpackage operand information
222 const FmqRequestDatum::OperandInformation& operandInfo =
223 data[index].inputOperandInformation();
224 index++;
225 const bool hasNoValue = operandInfo.hasNoValue;
226 const V1_0::DataLocation location = operandInfo.location;
227 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
228
229 // unpackage operand dimensions
230 std::vector<uint32_t> dimensions;
231 dimensions.reserve(numberOfDimensions);
232 for (size_t i = 0; i < numberOfDimensions; ++i) {
233 // validate dimension
234 if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
235 return NN_ERROR() << "FMQ Request packet ill-formed";
236 }
237
238 // unpackage dimension
239 const uint32_t dimension = data[index].inputOperandDimensionValue();
240 index++;
241
242 // store result
243 dimensions.push_back(dimension);
244 }
245
246 // store result
247 inputs.push_back(
248 {.hasNoValue = hasNoValue, .location = location, .dimensions = dimensions});
249 }
250
251 // unpackage output operands
252 std::vector<V1_0::RequestArgument> outputs;
253 outputs.reserve(numberOfOutputOperands);
254 for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) {
255 // validate output operand information
256 if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
257 return NN_ERROR() << "FMQ Request packet ill-formed";
258 }
259
260 // unpackage operand information
261 const FmqRequestDatum::OperandInformation& operandInfo =
262 data[index].outputOperandInformation();
263 index++;
264 const bool hasNoValue = operandInfo.hasNoValue;
265 const V1_0::DataLocation location = operandInfo.location;
266 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
267
268 // unpackage operand dimensions
269 std::vector<uint32_t> dimensions;
270 dimensions.reserve(numberOfDimensions);
271 for (size_t i = 0; i < numberOfDimensions; ++i) {
272 // validate dimension
273 if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
274 return NN_ERROR() << "FMQ Request packet ill-formed";
275 }
276
277 // unpackage dimension
278 const uint32_t dimension = data[index].outputOperandDimensionValue();
279 index++;
280
281 // store result
282 dimensions.push_back(dimension);
283 }
284
285 // store result
286 outputs.push_back(
287 {.hasNoValue = hasNoValue, .location = location, .dimensions = dimensions});
288 }
289
290 // unpackage pools
291 std::vector<int32_t> slots;
292 slots.reserve(numberOfPools);
293 for (size_t pool = 0; pool < numberOfPools; ++pool) {
294 // validate input operand information
295 if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
296 return NN_ERROR() << "FMQ Request packet ill-formed";
297 }
298
299 // unpackage operand information
300 const int32_t poolId = data[index].poolIdentifier();
301 index++;
302
303 // store result
304 slots.push_back(poolId);
305 }
306
307 // validate measureTiming
308 if (data[index].getDiscriminator() != discriminator::measureTiming) {
309 return NN_ERROR() << "FMQ Request packet ill-formed";
310 }
311
312 // unpackage measureTiming
313 const V1_2::MeasureTiming measure = data[index].measureTiming();
314 index++;
315
316 // validate packet information
317 if (index != packetSize) {
318 return NN_ERROR() << "FMQ Result packet ill-formed";
319 }
320
321 // return request
322 V1_0::Request request = {.inputs = inputs, .outputs = outputs, .pools = {}};
323 return std::make_tuple(std::move(request), std::move(slots), measure);
324 }
325
326 // deserialize a packet into the result
deserialize(const std::vector<FmqResultDatum> & data)327 nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::Timing>> deserialize(
328 const std::vector<FmqResultDatum>& data) {
329 using discriminator = FmqResultDatum::hidl_discriminator;
330 size_t index = 0;
331
332 // validate packet information
333 if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
334 return NN_ERROR() << "FMQ Result packet ill-formed";
335 }
336
337 // unpackage packet information
338 const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation();
339 index++;
340 const uint32_t packetSize = packetInfo.packetSize;
341 const V1_0::ErrorStatus errorStatus = packetInfo.errorStatus;
342 const uint32_t numberOfOperands = packetInfo.numberOfOperands;
343
344 // verify packet size
345 if (data.size() != packetSize) {
346 return NN_ERROR() << "FMQ Result packet ill-formed";
347 }
348
349 // unpackage operands
350 std::vector<V1_2::OutputShape> outputShapes;
351 outputShapes.reserve(numberOfOperands);
352 for (size_t operand = 0; operand < numberOfOperands; ++operand) {
353 // validate operand information
354 if (data[index].getDiscriminator() != discriminator::operandInformation) {
355 return NN_ERROR() << "FMQ Result packet ill-formed";
356 }
357
358 // unpackage operand information
359 const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation();
360 index++;
361 const bool isSufficient = operandInfo.isSufficient;
362 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
363
364 // unpackage operand dimensions
365 std::vector<uint32_t> dimensions;
366 dimensions.reserve(numberOfDimensions);
367 for (size_t i = 0; i < numberOfDimensions; ++i) {
368 // validate dimension
369 if (data[index].getDiscriminator() != discriminator::operandDimensionValue) {
370 return NN_ERROR() << "FMQ Result packet ill-formed";
371 }
372
373 // unpackage dimension
374 const uint32_t dimension = data[index].operandDimensionValue();
375 index++;
376
377 // store result
378 dimensions.push_back(dimension);
379 }
380
381 // store result
382 outputShapes.push_back({.dimensions = dimensions, .isSufficient = isSufficient});
383 }
384
385 // validate execution timing
386 if (data[index].getDiscriminator() != discriminator::executionTiming) {
387 return NN_ERROR() << "FMQ Result packet ill-formed";
388 }
389
390 // unpackage execution timing
391 const V1_2::Timing timing = data[index].executionTiming();
392 index++;
393
394 // validate packet information
395 if (index != packetSize) {
396 return NN_ERROR() << "FMQ Result packet ill-formed";
397 }
398
399 // return result
400 return std::make_tuple(errorStatus, std::move(outputShapes), timing);
401 }
402
403 // RequestChannelSender methods
404
405 nn::GeneralResult<
406 std::pair<std::unique_ptr<RequestChannelSender>, const MQDescriptorSync<FmqRequestDatum>*>>
create(size_t channelLength)407 RequestChannelSender::create(size_t channelLength) {
408 auto requestChannelSender =
409 std::make_unique<RequestChannelSender>(PrivateConstructorTag{}, channelLength);
410 if (!requestChannelSender->mFmqRequestChannel.isValid()) {
411 return NN_ERROR() << "Unable to create RequestChannelSender";
412 }
413
414 const MQDescriptorSync<FmqRequestDatum>* descriptor =
415 requestChannelSender->mFmqRequestChannel.getDesc();
416 return std::make_pair(std::move(requestChannelSender), descriptor);
417 }
418
RequestChannelSender(PrivateConstructorTag,size_t channelLength)419 RequestChannelSender::RequestChannelSender(PrivateConstructorTag /*tag*/, size_t channelLength)
420 : mFmqRequestChannel(channelLength, /*configureEventFlagWord=*/true) {}
421
send(const V1_0::Request & request,V1_2::MeasureTiming measure,const std::vector<int32_t> & slots)422 nn::Result<void> RequestChannelSender::send(const V1_0::Request& request,
423 V1_2::MeasureTiming measure,
424 const std::vector<int32_t>& slots) {
425 const std::vector<FmqRequestDatum> serialized = serialize(request, measure, slots);
426 return sendPacket(serialized);
427 }
428
sendPacket(const std::vector<FmqRequestDatum> & packet)429 nn::Result<void> RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
430 if (!mValid) {
431 return NN_ERROR() << "FMQ object is invalid";
432 }
433
434 if (packet.size() > mFmqRequestChannel.availableToWrite()) {
435 return NN_ERROR()
436 << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ";
437 }
438
439 // Always send the packet with "blocking" because this signals the futex and unblocks the
440 // consumer if it is waiting on the futex.
441 const bool success = mFmqRequestChannel.writeBlocking(packet.data(), packet.size());
442 if (!success) {
443 return NN_ERROR()
444 << "RequestChannelSender::sendPacket -- FMQ's writeBlocking returned an error";
445 }
446
447 return {};
448 }
449
notifyAsDeadObject()450 void RequestChannelSender::notifyAsDeadObject() {
451 mValid = false;
452 }
453
454 // RequestChannelReceiver methods
455
create(const MQDescriptorSync<FmqRequestDatum> & requestChannel,std::chrono::microseconds pollingTimeWindow)456 nn::GeneralResult<std::unique_ptr<RequestChannelReceiver>> RequestChannelReceiver::create(
457 const MQDescriptorSync<FmqRequestDatum>& requestChannel,
458 std::chrono::microseconds pollingTimeWindow) {
459 auto requestChannelReceiver = std::make_unique<RequestChannelReceiver>(
460 PrivateConstructorTag{}, requestChannel, pollingTimeWindow);
461
462 if (!requestChannelReceiver->mFmqRequestChannel.isValid()) {
463 return NN_ERROR() << "Unable to create RequestChannelReceiver";
464 }
465 if (requestChannelReceiver->mFmqRequestChannel.getEventFlagWord() == nullptr) {
466 return NN_ERROR()
467 << "RequestChannelReceiver::create was passed an MQDescriptor without an EventFlag";
468 }
469
470 return requestChannelReceiver;
471 }
472
RequestChannelReceiver(PrivateConstructorTag,const MQDescriptorSync<FmqRequestDatum> & requestChannel,std::chrono::microseconds pollingTimeWindow)473 RequestChannelReceiver::RequestChannelReceiver(
474 PrivateConstructorTag /*tag*/, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
475 std::chrono::microseconds pollingTimeWindow)
476 : mFmqRequestChannel(requestChannel), kPollingTimeWindow(pollingTimeWindow) {}
477
478 nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
getBlocking()479 RequestChannelReceiver::getBlocking() {
480 const auto packet = NN_TRY(getPacketBlocking());
481 return deserialize(packet);
482 }
483
invalidate()484 void RequestChannelReceiver::invalidate() {
485 mTeardown = true;
486
487 // force unblock
488 // ExecutionBurstServer is by default waiting on a request packet. If the client process
489 // destroys its burst object, the server may still be waiting on the futex. This force unblock
490 // wakes up any thread waiting on the futex.
491 const auto data = serialize(V1_0::Request{}, V1_2::MeasureTiming::NO, {});
492 mFmqRequestChannel.writeBlocking(data.data(), data.size());
493 }
494
getPacketBlocking()495 nn::Result<std::vector<FmqRequestDatum>> RequestChannelReceiver::getPacketBlocking() {
496 if (mTeardown) {
497 return NN_ERROR() << "FMQ object is being torn down";
498 }
499
500 // First spend time polling if results are available in FMQ instead of waiting on the futex.
501 // Polling is more responsive (yielding lower latencies), but can take up more power, so only
502 // poll for a limited period of time.
503
504 auto& getCurrentTime = std::chrono::high_resolution_clock::now;
505 const auto timeToStopPolling = getCurrentTime() + kPollingTimeWindow;
506
507 while (getCurrentTime() < timeToStopPolling) {
508 // if class is being torn down, immediately return
509 if (mTeardown.load(std::memory_order_relaxed)) {
510 return NN_ERROR() << "FMQ object is being torn down";
511 }
512
513 // Check if data is available. If it is, immediately retrieve it and return.
514 const size_t available = mFmqRequestChannel.availableToRead();
515 if (available > 0) {
516 std::vector<FmqRequestDatum> packet(available);
517 const bool success = mFmqRequestChannel.readBlocking(packet.data(), available);
518 if (!success) {
519 return NN_ERROR() << "Error receiving packet";
520 }
521 return packet;
522 }
523
524 std::this_thread::yield();
525 }
526
527 // If we get to this point, we either stopped polling because it was taking too long or polling
528 // was not allowed. Instead, perform a blocking call which uses a futex to save power.
529
530 // wait for request packet and read first element of request packet
531 FmqRequestDatum datum;
532 bool success = mFmqRequestChannel.readBlocking(&datum, 1);
533
534 // retrieve remaining elements
535 // NOTE: all of the data is already available at this point, so there's no need to do a blocking
536 // wait to wait for more data. This is known because in FMQ, all writes are published (made
537 // available) atomically. Currently, the producer always publishes the entire packet in one
538 // function call, so if the first element of the packet is available, the remaining elements are
539 // also available.
540 const size_t count = mFmqRequestChannel.availableToRead();
541 std::vector<FmqRequestDatum> packet(count + 1);
542 std::memcpy(&packet.front(), &datum, sizeof(datum));
543 success &= mFmqRequestChannel.read(packet.data() + 1, count);
544
545 // terminate loop
546 if (mTeardown) {
547 return NN_ERROR() << "FMQ object is being torn down";
548 }
549
550 // ensure packet was successfully received
551 if (!success) {
552 return NN_ERROR() << "Error receiving packet";
553 }
554
555 return packet;
556 }
557
558 // ResultChannelSender methods
559
create(const MQDescriptorSync<FmqResultDatum> & resultChannel)560 nn::GeneralResult<std::unique_ptr<ResultChannelSender>> ResultChannelSender::create(
561 const MQDescriptorSync<FmqResultDatum>& resultChannel) {
562 auto resultChannelSender =
563 std::make_unique<ResultChannelSender>(PrivateConstructorTag{}, resultChannel);
564
565 if (!resultChannelSender->mFmqResultChannel.isValid()) {
566 return NN_ERROR() << "Unable to create RequestChannelSender";
567 }
568 if (resultChannelSender->mFmqResultChannel.getEventFlagWord() == nullptr) {
569 return NN_ERROR()
570 << "ResultChannelSender::create was passed an MQDescriptor without an EventFlag";
571 }
572
573 return resultChannelSender;
574 }
575
ResultChannelSender(PrivateConstructorTag,const MQDescriptorSync<FmqResultDatum> & resultChannel)576 ResultChannelSender::ResultChannelSender(PrivateConstructorTag /*tag*/,
577 const MQDescriptorSync<FmqResultDatum>& resultChannel)
578 : mFmqResultChannel(resultChannel) {}
579
send(V1_0::ErrorStatus errorStatus,const std::vector<V1_2::OutputShape> & outputShapes,V1_2::Timing timing)580 void ResultChannelSender::send(V1_0::ErrorStatus errorStatus,
581 const std::vector<V1_2::OutputShape>& outputShapes,
582 V1_2::Timing timing) {
583 const std::vector<FmqResultDatum> serialized = serialize(errorStatus, outputShapes, timing);
584 sendPacket(serialized);
585 }
586
sendPacket(const std::vector<FmqResultDatum> & packet)587 void ResultChannelSender::sendPacket(const std::vector<FmqResultDatum>& packet) {
588 if (packet.size() > mFmqResultChannel.availableToWrite()) {
589 LOG(ERROR)
590 << "ResultChannelSender::sendPacket -- packet size exceeds size available in FMQ";
591 const std::vector<FmqResultDatum> errorPacket =
592 serialize(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
593
594 // Always send the packet with "blocking" because this signals the futex and unblocks the
595 // consumer if it is waiting on the futex.
596 mFmqResultChannel.writeBlocking(errorPacket.data(), errorPacket.size());
597 } else {
598 // Always send the packet with "blocking" because this signals the futex and unblocks the
599 // consumer if it is waiting on the futex.
600 mFmqResultChannel.writeBlocking(packet.data(), packet.size());
601 }
602 }
603
604 // ResultChannelReceiver methods
605
606 nn::GeneralResult<
607 std::pair<std::unique_ptr<ResultChannelReceiver>, const MQDescriptorSync<FmqResultDatum>*>>
create(size_t channelLength,std::chrono::microseconds pollingTimeWindow)608 ResultChannelReceiver::create(size_t channelLength, std::chrono::microseconds pollingTimeWindow) {
609 auto resultChannelReceiver = std::make_unique<ResultChannelReceiver>(
610 PrivateConstructorTag{}, channelLength, pollingTimeWindow);
611 if (!resultChannelReceiver->mFmqResultChannel.isValid()) {
612 return NN_ERROR() << "Unable to create ResultChannelReceiver";
613 }
614
615 const MQDescriptorSync<FmqResultDatum>* descriptor =
616 resultChannelReceiver->mFmqResultChannel.getDesc();
617 return std::make_pair(std::move(resultChannelReceiver), descriptor);
618 }
619
ResultChannelReceiver(PrivateConstructorTag,size_t channelLength,std::chrono::microseconds pollingTimeWindow)620 ResultChannelReceiver::ResultChannelReceiver(PrivateConstructorTag /*tag*/, size_t channelLength,
621 std::chrono::microseconds pollingTimeWindow)
622 : mFmqResultChannel(channelLength, /*configureEventFlagWord=*/true),
623 kPollingTimeWindow(pollingTimeWindow) {}
624
625 nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::Timing>>
getBlocking()626 ResultChannelReceiver::getBlocking() {
627 const auto packet = NN_TRY(getPacketBlocking());
628 return deserialize(packet);
629 }
630
notifyAsDeadObject()631 void ResultChannelReceiver::notifyAsDeadObject() {
632 mValid = false;
633
634 // force unblock
635 // ExecutionBurstController waits on a result packet after sending a request. If the driver
636 // containing ExecutionBurstServer crashes, the controller may be waiting on the futex. This
637 // force unblock wakes up any thread waiting on the futex.
638 const auto data = serialize(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
639 mFmqResultChannel.writeBlocking(data.data(), data.size());
640 }
641
getPacketBlocking()642 nn::Result<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
643 if (!mValid) {
644 return NN_ERROR() << "FMQ object is invalid";
645 }
646
647 // First spend time polling if results are available in FMQ instead of waiting on the futex.
648 // Polling is more responsive (yielding lower latencies), but can take up more power, so only
649 // poll for a limited period of time.
650
651 auto& getCurrentTime = std::chrono::high_resolution_clock::now;
652 const auto timeToStopPolling = getCurrentTime() + kPollingTimeWindow;
653
654 while (getCurrentTime() < timeToStopPolling) {
655 // if class is being torn down, immediately return
656 if (!mValid.load(std::memory_order_relaxed)) {
657 return NN_ERROR() << "FMQ object is invalid";
658 }
659
660 // Check if data is available. If it is, immediately retrieve it and return.
661 const size_t available = mFmqResultChannel.availableToRead();
662 if (available > 0) {
663 std::vector<FmqResultDatum> packet(available);
664 const bool success = mFmqResultChannel.readBlocking(packet.data(), available);
665 if (!success) {
666 return NN_ERROR() << "Error receiving packet";
667 }
668 return packet;
669 }
670
671 std::this_thread::yield();
672 }
673
674 // If we get to this point, we either stopped polling because it was taking too long or polling
675 // was not allowed. Instead, perform a blocking call which uses a futex to save power.
676
677 // wait for result packet and read first element of result packet
678 FmqResultDatum datum;
679 bool success = mFmqResultChannel.readBlocking(&datum, 1);
680
681 // retrieve remaining elements
682 // NOTE: all of the data is already available at this point, so there's no need to do a blocking
683 // wait to wait for more data. This is known because in FMQ, all writes are published (made
684 // available) atomically. Currently, the producer always publishes the entire packet in one
685 // function call, so if the first element of the packet is available, the remaining elements are
686 // also available.
687 const size_t count = mFmqResultChannel.availableToRead();
688 std::vector<FmqResultDatum> packet(count + 1);
689 std::memcpy(&packet.front(), &datum, sizeof(datum));
690 success &= mFmqResultChannel.read(packet.data() + 1, count);
691
692 if (!mValid) {
693 return NN_ERROR() << "FMQ object is invalid";
694 }
695
696 // ensure packet was successfully received
697 if (!success) {
698 return NN_ERROR() << "Error receiving packet";
699 }
700
701 return packet;
702 }
703
704 } // namespace android::hardware::neuralnetworks::V1_2::utils
705