1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "MetaModel"
18
19 #include "MetaModel.h"
20
21 #include <algorithm>
22 #include <map>
23 #include <numeric>
24 #include <set>
25 #include <sstream>
26 #include <type_traits>
27 #include <utility>
28 #include <vector>
29
30 #include "GraphDump.h"
31 #include "LegacyUtils.h"
32 #include "nnapi/TypeUtils.h"
33 #include "nnapi/Types.h"
34 #include "nnapi/Validation.h"
35
36 namespace android::nn {
37
38 namespace {
39
// Appends a copy of `val` to the end of `*vec`.
//
// Returns a pair consisting of the index of the new element and a pointer to
// the new element. The pointer refers to storage owned by `*vec`, so it is only
// usable until `*vec` is next resized or reallocated. `vec` is dereferenced
// unconditionally and therefore must not be null.
template <class T>
std::pair<uint32_t, T*> extend(std::vector<T>* vec, const T& val) {
    vec->push_back(val);
    // The index is reported as a 32-bit value; make the narrowing from size_t
    // explicit instead of relying on an implicit conversion.
    const auto newIndex = static_cast<uint32_t>(vec->size() - 1);
    return {newIndex, &vec->back()};
}
48
// Appends a value-initialized element to the end of `*vec`.
//
// Returns a pair consisting of the index of the new element and a pointer to
// the new element. The pointer refers to storage owned by `*vec`, so it is only
// usable until `*vec` is next resized or reallocated. `vec` is dereferenced
// unconditionally and therefore must not be null.
template <class T>
std::pair<uint32_t, T*> extend(std::vector<T>* vec) {
    // Construct the element in place rather than copying a temporary T; this
    // avoids a copy and does not require T to be copy-constructible.
    vec->emplace_back();
    return {static_cast<uint32_t>(vec->size() - 1), &vec->back()};
}
55
invalid(const Model & model,Version version,bool strictSlicing)56 bool invalid(const Model& model, Version version, bool strictSlicing) {
57 // A model must have at least one operation. However, it's possible that a
58 // slice has no operations (because no operations from the original model
59 // are compliant with the sliced model type). In this case, the sliced
60 // model would be invalid.
61 const bool looksEmpty = (model.main.operations.size() == 0);
62 if (strictSlicing) {
63 CHECK_EQ(looksEmpty, (model.main.operands.size() == 0));
64 }
65 if (looksEmpty) return true;
66
67 // A model must have at least one output. However, it's possible for a
68 // model to contain dead operations (i.e., outputs on which no model outputs
69 // are data dependent). A slice might contain only dead operations, and
70 // hence have no model outputs. In this case, the sliced model would be
71 // invalid.
72 if (model.main.outputIndexes.size() == 0) return true;
73
74 // We shouldn't have to check whether the model is valid. However, it could
75 // be invalid if there is an error in the slicing algorithm.
76 auto maybeVersion = validate(model);
77 if (!maybeVersion.has_value()) {
78 LOG(WARNING) << "Sliced model fails validate(): " << maybeVersion.error();
79 CHECK(!strictSlicing);
80 return true;
81 }
82 if (!isCompliantVersion(maybeVersion.value(), version)) {
83 LOG(WARNING) << "Sliced model fails validate(): insufficient version ("
84 << maybeVersion.value() << " vs " << version << ")";
85 CHECK(!strictSlicing);
86 return true;
87 }
88
89 return false;
90 }
91
92 } // anonymous namespace
93
MetaModel(Model model,bool strictSlicing)94 MetaModel::MetaModel(Model model, bool strictSlicing)
95 : mModel(std::move(model)),
96 mModelMinimumSupportedVersion(validate(mModel).value()),
97 mStrictSlicing(strictSlicing) {}
98
getSlice(Version version) const99 MetaModel::ReturnedSlice MetaModel::getSlice(Version version) const {
100 // All slices of versions of at least mModelMinimumSupportedVersion are identical, so do not
101 // create more than one such slice.
102 version.level = std::min(version.level, mModelMinimumSupportedVersion.level);
103 version.runtimeOnlyFeatures &= mModelMinimumSupportedVersion.runtimeOnlyFeatures;
104
105 auto& slice = mCachedSlices[version];
106 if (slice.mState == SliceState::UNINITIALIZED) {
107 slice = makeSlice(version);
108 }
109 if (slice.mState == SliceState::INVALID) {
110 return {};
111 }
112 return MetaModel::ReturnedSlice(std::make_pair(
113 slice.mModel, Mapper([&slice](uint32_t slicedOperationIndex) {
114 return slice.mSlicedOperationIndexToOrigIndex.at(slicedOperationIndex);
115 })));
116 }
117
118 // Utility class for makeSlice().
119 //
120 // For each output operand of a noncompliant operation that is the input
121 // operand of at least one compliant operation, we will ensure that there is
122 // a sliced model input whose "type" is that of the output operand. This is
123 // a map from operand "type" (in the original model) to model input operand
124 // index (in the sliced model). We only use the subset of the fields that are
125 // relevant (OperandType, dimensions, scale, zeroPoint, extraParams), but
126 // exclude irrelevant fields from the map key (lifetime, location).
127 //
128 // We also use this map for model input operands of the original model that
129 // become input operands of the sliced model. This means that an original
130 // model input operand might be commoned with other original model input
131 // operands and/or with original model temporary operands.
132 class MetaModel::OrigOperandToSlicedInputOperandIndex {
133 public:
134 // `slicedOperands` and `slicedInputIndexes` will be modified as part of
135 // OrigOperandToSlicedInputOperandIndex::getIndex. `slicedVersion`, `operandValuesSize`, and
136 // `poolSizes` are used as a check to ensure that the sliced operand is valid and compliant with
137 // the sliced version. `operandValuesSize` is the size of the operand values in the sliced model
138 // (which is the same as the original model). `poolSizes` is the size of the memories in the
139 // sliced model (which is the same as the original model).
OrigOperandToSlicedInputOperandIndex(std::vector<Operand> * slicedOperands,std::vector<uint32_t> * slicedInputIndexes,Version slicedVersion,size_t operandValuesSize,std::vector<size_t> poolSizes)140 OrigOperandToSlicedInputOperandIndex(std::vector<Operand>* slicedOperands,
141 std::vector<uint32_t>* slicedInputIndexes,
142 Version slicedVersion, size_t operandValuesSize,
143 std::vector<size_t> poolSizes)
144 : mSlicedOperands(*slicedOperands),
145 mSlicedInputIndexes(*slicedInputIndexes),
146 kSlicedVersion(slicedVersion),
147 kOperandValuesSize(operandValuesSize),
148 kPoolSizes(std::move(poolSizes)) {}
149
150 // Given an operand from the original model, return the index of the
151 // corresponding model input operand from the sliced model. Creates a
152 // new operand in the sliced model if necessary.
getIndex(Operand operand)153 uint32_t getIndex(Operand operand) {
154 CHECK(operand.lifetime == Operand::LifeTime::SUBGRAPH_INPUT ||
155 operand.lifetime == Operand::LifeTime::SUBGRAPH_OUTPUT ||
156 operand.lifetime == Operand::LifeTime::TEMPORARY_VARIABLE);
157
158 // Lookup
159 auto it = mMap.find(operand);
160 if (it != mMap.end()) {
161 VLOG(COMPILATION) << "OrigOperandToSlicedInputOperandIndex::getIndex looked for "
162 << operand << " and found " << it->second << ": " << it->first;
163 return it->second;
164 }
165
166 // Create
167 operand.lifetime = Operand::LifeTime::SUBGRAPH_INPUT;
168 operand.location = {};
169
170 // Note that the sliced model does not contain any referenced subgraphs, so both `subgraphs`
171 // and `subgraphVersionCache` are empty.
172 const std::vector<Model::Subgraph> subgraphs;
173 auto subgraphVersionCache = createSubgraphVersionCache(subgraphs.size());
174 const auto minimumSupportedOperandVersion =
175 validateOperandAndAnythingItDependsOn(operand, kOperandValuesSize, kPoolSizes,
176 subgraphs, subgraphVersionCache.get())
177 .value();
178 CHECK(isCompliantVersion(minimumSupportedOperandVersion, kSlicedVersion));
179
180 uint32_t slicedOperandIndex = extend(&mSlicedOperands, operand).first;
181 mMap[operand] = slicedOperandIndex;
182 extend(&mSlicedInputIndexes, slicedOperandIndex);
183 VLOG(COMPILATION) << "OrigOperandToSlicedInputOperandIndex::getIndex created "
184 << slicedOperandIndex << ": " << operand;
185 return slicedOperandIndex;
186 }
187
188 private:
189 class Compare {
190 public:
operator ()(const Operand & a,const Operand & b) const191 bool operator()(const Operand& a, const Operand& b) const {
192 if (a.type != b.type) {
193 return a.type < b.type;
194 }
195 if (a.dimensions != b.dimensions) {
196 return a.dimensions < b.dimensions;
197 }
198 if (a.scale != b.scale) {
199 return a.scale < b.scale;
200 }
201 if (a.zeroPoint != b.zeroPoint) {
202 return a.zeroPoint < b.zeroPoint;
203 }
204 return compare(a.extraParams, b.extraParams);
205 }
206
207 private:
compare(const Operand::SymmPerChannelQuantParams & a,const Operand::SymmPerChannelQuantParams & b)208 static bool compare(const Operand::SymmPerChannelQuantParams& a,
209 const Operand::SymmPerChannelQuantParams& b) {
210 if (a.scales != b.scales) {
211 return a.scales < b.scales;
212 }
213 return a.channelDim < b.channelDim;
214 }
compare(const Operand::ExtraParams & a,const Operand::ExtraParams & b)215 static bool compare(const Operand::ExtraParams& a, const Operand::ExtraParams& b) {
216 if (a.index() != b.index()) {
217 return a.index() < b.index();
218 }
219 if (std::holds_alternative<Operand::SymmPerChannelQuantParams>(a)) {
220 return compare(std::get<Operand::SymmPerChannelQuantParams>(a),
221 std::get<Operand::SymmPerChannelQuantParams>(b));
222 }
223 if (std::holds_alternative<Operand::ExtensionParams>(a)) {
224 return std::get<Operand::ExtensionParams>(a) <
225 std::get<Operand::ExtensionParams>(b);
226 }
227 if (std::holds_alternative<Operand::NoParams>(a)) {
228 return false;
229 }
230 CHECK(false) << "Unexpected";
231 return false;
232 }
233 };
234 std::map<Operand, uint32_t, Compare> mMap;
235 std::vector<Operand>& mSlicedOperands;
236 std::vector<uint32_t>& mSlicedInputIndexes;
237 const Version kSlicedVersion;
238 const size_t kOperandValuesSize;
239 const std::vector<size_t> kPoolSizes;
240 };
241
processOperations(Slice * slice,std::map<uint32_t,uint32_t> * origOperandIndexToSlicedIndex,OrigOperandToSlicedInputOperandIndex * origOperandToSlicedInputOperandIndex,const std::set<uint32_t> & noncompliantOperations,const std::set<uint32_t> & inputOperandIndexesOfCompliantOperations) const242 void MetaModel::processOperations(
243 Slice* slice, std::map<uint32_t, uint32_t>* origOperandIndexToSlicedIndex,
244 OrigOperandToSlicedInputOperandIndex* origOperandToSlicedInputOperandIndex,
245 const std::set<uint32_t>& noncompliantOperations,
246 const std::set<uint32_t>& inputOperandIndexesOfCompliantOperations) const {
247 const auto& origOperands = mModel.main.operands;
248 const auto& origOperations = mModel.main.operations;
249 auto& slicedOperands = slice->mModel.main.operands;
250 auto& slicedOperations = slice->mModel.main.operations;
251
252 std::vector<uint32_t> origOperandNumberOfConsumers =
253 countNumberOfConsumers(origOperands.size(), origOperations).value();
254
255 for (uint32_t origOperationIndex = 0; origOperationIndex < origOperations.size();
256 ++origOperationIndex) {
257 const Operation& origOperation = origOperations[origOperationIndex];
258
259 if (noncompliantOperations.count(origOperationIndex)) {
260 for (uint32_t output : origOperation.outputs) {
261 if (!inputOperandIndexesOfCompliantOperations.count(output)) {
262 continue;
263 }
264 const uint32_t slicedIndex =
265 origOperandToSlicedInputOperandIndex->getIndex(origOperands[output]);
266 (*origOperandIndexToSlicedIndex)[output] = slicedIndex;
267 VLOG(COMPILATION)
268 << "origOperandIndexToSlicedIndex noncompliant output processing created "
269 << output << " -> " << slicedIndex << ": " << slicedOperands[slicedIndex];
270 }
271 } else {
272 slice->mSlicedOperationIndexToOrigIndex.push_back(origOperationIndex);
273 Operation& slicedOperation = *extend(&slicedOperations).second;
274 CHECK_EQ(slice->mSlicedOperationIndexToOrigIndex.size(), slicedOperations.size());
275
276 slicedOperation.type = origOperation.type;
277
278 // Model is topologically sorted, so all operation inputs must be
279 // present in origOperandIndexToSlicedIndex, and no operation
280 // outputs may be.
281
282 // Operation inputs
283 // - Fill in slicedOperation.inputs
284 slicedOperation.inputs.resize(origOperation.inputs.size());
285 std::transform(
286 origOperation.inputs.begin(), origOperation.inputs.end(),
287 slicedOperation.inputs.begin(),
288 [&origOperandIndexToSlicedIndex, &slicedOperands](uint32_t origOperandIndex) {
289 uint32_t slicedOperandIndex =
290 origOperandIndexToSlicedIndex->at(origOperandIndex);
291 VLOG(COMPILATION) << "origOperandIndexToSlicedIndex compliant input "
292 "processing created "
293 << origOperandIndex << " -> " << slicedOperandIndex
294 << ": " << slicedOperands[slicedOperandIndex];
295 return slicedOperandIndex;
296 });
297
298 // Operation outputs
299 // - Add new operands to slicedOperands
300 // - Update origOperandIndexToSlicedIndex
301 // - Fill in slicedOperation.outputs
302 // - Record as a model output, if necessary
303 const uint32_t firstOutputSlicedOperandIndex = slicedOperands.size();
304 slicedOperands.resize(firstOutputSlicedOperandIndex + origOperation.outputs.size());
305 slicedOperation.outputs.resize(origOperation.outputs.size());
306 for (uint32_t outputNum = 0; outputNum < slicedOperation.outputs.size(); ++outputNum) {
307 uint32_t origOperandIndex = origOperation.outputs[outputNum];
308 uint32_t slicedOperandIndex = firstOutputSlicedOperandIndex + outputNum;
309 auto& slicedOperand = slicedOperands[slicedOperandIndex];
310 const auto& origOperand = origOperands[origOperandIndex];
311 slicedOperand = origOperand;
312
313 CHECK_EQ(origOperandIndexToSlicedIndex->count(origOperandIndex), size_t(0));
314 (*origOperandIndexToSlicedIndex)[origOperandIndex] = slicedOperandIndex;
315 slicedOperation.outputs[outputNum] = slicedOperandIndex;
316
317 const auto subgraphOutputLifetime = Operand::LifeTime::SUBGRAPH_OUTPUT;
318 if (!inputOperandIndexesOfCompliantOperations.count(origOperandIndex) &&
319 origOperandNumberOfConsumers[origOperandIndex] != 0) {
320 // Was consumed only by noncompliant operations; convert to
321 // an output of the sliced model.
322 slicedOperand.lifetime = subgraphOutputLifetime;
323 }
324
325 VLOG(COMPILATION) << "origOperandIndexToSlicedIndex compliant output created "
326 << origOperandIndex << " -> " << slicedOperandIndex << ": "
327 << slicedOperand;
328
329 if (slicedOperand.lifetime == subgraphOutputLifetime) {
330 extend(&slice->mModel.main.outputIndexes, slicedOperandIndex);
331 }
332 }
333 }
334 }
335 }
336
getNoncompliantOperations(Version version) const337 std::set<uint32_t> MetaModel::getNoncompliantOperations(Version version) const {
338 const auto [operandValuesSize, poolSizes] = getMemorySizes(mModel);
339
340 auto subgraphVersionCache = createSubgraphVersionCache(mModel.referenced.size());
341 std::set<uint32_t> noncompliantOperations;
342 for (uint32_t i = 0; i < mModel.main.operations.size(); ++i) {
343 const auto& operation = mModel.main.operations[i];
344 const auto minSupportedVersion =
345 validateOperationAndAnythingItDependsOn(
346 operation, mModel.main.operands, operandValuesSize, poolSizes,
347 mModel.referenced, subgraphVersionCache.get())
348 .value();
349 if (!isCompliantVersion(minSupportedVersion, version)) {
350 noncompliantOperations.insert(i);
351 }
352 }
353 return noncompliantOperations;
354 }
355
operator ()(Version lhs,Version rhs) const356 bool MetaModel::Comparison::operator()(Version lhs, Version rhs) const {
357 constexpr auto toTuple = [](const Version& v) {
358 return std::tie(v.level, v.runtimeOnlyFeatures);
359 };
360 // Lexicographical comparison of the fields. The bool is promoted to an integer for the
361 // comparison such that "false < true".
362 return toTuple(lhs) < toTuple(rhs);
363 }
364
makeSlice(Version version) const365 MetaModel::Slice MetaModel::makeSlice(Version version) const {
366 Slice slice;
367
368 // Quickly return if the model is already compliant with `version`
369 if (isCompliantVersion(mModelMinimumSupportedVersion, version)) {
370 slice.mModel = mModel;
371 slice.mSlicedOperationIndexToOrigIndex =
372 std::vector<uint32_t>(mModel.main.operations.size());
373 std::iota(slice.mSlicedOperationIndexToOrigIndex.begin(),
374 slice.mSlicedOperationIndexToOrigIndex.end(), 0u);
375 slice.mState = SliceState::NORMAL;
376 return slice;
377 }
378
379 const auto& origOperands = mModel.main.operands;
380 const auto& origOperations = mModel.main.operations;
381 auto& slicedOperands = slice.mModel.main.operands;
382
383 // Indexes of elements of noncompliant origOperations
384 std::set<uint32_t> noncompliantOperations = getNoncompliantOperations(version);
385
386 // Check if any compliant operations require a subgraph.
387 bool someCompliantOperationHasASubgraphOperand = false;
388 if (!mModel.referenced.empty()) {
389 for (size_t i = 0; i < mModel.main.operations.size(); ++i) {
390 const auto& operation = mModel.main.operations[i];
391 if (noncompliantOperations.count(i) > 0) {
392 continue;
393 }
394 const auto isSubgraph = [&origOperands](uint32_t opndIdx) {
395 return origOperands[opndIdx].lifetime == Operand::LifeTime::SUBGRAPH;
396 };
397 if (std::any_of(operation.inputs.begin(), operation.inputs.end(), isSubgraph)) {
398 someCompliantOperationHasASubgraphOperand = true;
399 break;
400 }
401 }
402 }
403
404 // TODO(b/175418767): Currently, MetaModel is not equipped to slice referenced subgraphs. If the
405 // original model is not compliant with the specified version and contains referenced subgraphs
406 // needed by the slice, return an invalidated slice.
407 if (someCompliantOperationHasASubgraphOperand) {
408 slice.mState = SliceState::INVALID;
409 return slice;
410 }
411
412 // Map from an operand index in origOperands to the corresponding operand index in
413 // slicedOperands
414 std::map<uint32_t, uint32_t> origOperandIndexToSlicedIndex;
415
416 // Collect the operand indexes of every operand that is an input to a
417 // compliant operation. If the operand is a CONSTANT_*, POINTER, or a
418 // NO_VALUE, copy it to the sliced model and update
419 // origOperandIndexToSlicedIndex accordingly. Otherwise, we'll deal with
420 // the operand in the subsequent "Main loop", where we process operation
421 // outputs (intermediates and model outputs).
422 std::set<uint32_t> inputOperandIndexesOfCompliantOperations;
423 for (uint32_t origOperationIndex = 0; origOperationIndex < origOperations.size();
424 ++origOperationIndex) {
425 if (noncompliantOperations.count(origOperationIndex)) {
426 continue;
427 }
428 for (uint32_t input : origOperations[origOperationIndex].inputs) {
429 if (inputOperandIndexesOfCompliantOperations.insert(input).second) {
430 const Operand& origOperand = origOperands[input];
431 switch (origOperand.lifetime) {
432 case Operand::LifeTime::CONSTANT_COPY:
433 case Operand::LifeTime::CONSTANT_REFERENCE:
434 case Operand::LifeTime::POINTER:
435 case Operand::LifeTime::NO_VALUE: {
436 const uint32_t slicedOperandIndex =
437 extend(&slicedOperands, origOperand).first;
438 origOperandIndexToSlicedIndex[input] = slicedOperandIndex;
439 VLOG(COMPILATION) << "origOperandIndexToSlicedIndex initialization created "
440 << input << " -> " << slicedOperandIndex << ": "
441 << slicedOperands[slicedOperandIndex];
442 break;
443 }
444 default:
445 break;
446 }
447 }
448 }
449 }
450
451 const auto [operandValuesSize, poolSizes] = getMemorySizes(mModel);
452
453 OrigOperandToSlicedInputOperandIndex origOperandToSlicedInputOperandIndex(
454 &slicedOperands, &slice.mModel.main.inputIndexes, version, operandValuesSize,
455 poolSizes);
456
457 // An input of the original model is an input of the sliced model if and
458 // only if it is consumed by at least one compliant operation. Note that in
459 // the sliced model we share all model inputs of the same "type"; and that
460 // we may later add model inputs to the sliced model.
461 for (uint32_t origInputIndex : mModel.main.inputIndexes) {
462 if (inputOperandIndexesOfCompliantOperations.count(origInputIndex)) {
463 const uint32_t slicedIndex =
464 origOperandToSlicedInputOperandIndex.getIndex(origOperands[origInputIndex]);
465 origOperandIndexToSlicedIndex[origInputIndex] = slicedIndex;
466 VLOG(COMPILATION) << "origOperandIndexToSlicedIndex inputIndexes processing created "
467 << origInputIndex << " -> " << slicedIndex << ": "
468 << slicedOperands[slicedIndex];
469 }
470 }
471
472 // Main loop: Process each operation of the original model.
473 processOperations(&slice, &origOperandIndexToSlicedIndex, &origOperandToSlicedInputOperandIndex,
474 noncompliantOperations, inputOperandIndexesOfCompliantOperations);
475
476 // To keep things simple, we copy over these fields as-is. We could instead
477 // opt to regenerate them based on the operands present in the sliced model:
478 // This would be more complex and probably take more computation time, but
479 // it would reduce the size of the sliced model, and hence the time spent
480 // copying it around and potentially passing it across process boundaries.
481 slice.mModel.operandValues = mModel.operandValues;
482 slice.mModel.pools = mModel.pools;
483
484 if (VLOG_IS_ON(COMPILATION)) {
485 {
486 std::ostringstream fromName;
487 fromName << "Slice: From canonical";
488 graphDump(fromName.str().c_str(), mModel);
489 }
490 {
491 std::ostringstream toName;
492 toName << "Slice: To " << version;
493 graphDump(toName.str().c_str(), slice.mModel);
494 }
495 }
496
497 slice.mState = invalid(slice.mModel, version, mStrictSlicing) ? SliceState::INVALID
498 : SliceState::NORMAL;
499
500 return slice;
501 }
502
503 } // namespace android::nn
504