1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "MetaModel"
18 
19 #include "MetaModel.h"
20 
21 #include <algorithm>
22 #include <map>
23 #include <numeric>
24 #include <set>
25 #include <sstream>
26 #include <type_traits>
27 #include <utility>
28 #include <vector>
29 
30 #include "GraphDump.h"
31 #include "LegacyUtils.h"
32 #include "nnapi/TypeUtils.h"
33 #include "nnapi/Types.h"
34 #include "nnapi/Validation.h"
35 
36 namespace android::nn {
37 
38 namespace {
39 
// Appends a copy of `val` to `*vec`.
//
// Returns a pair holding the index at which the copy was stored and a pointer
// to the stored element.  The pointer is taken from the vector after the
// append, so it refers to the element inside `*vec`, not to `val`.
template <class T>
std::pair<uint32_t, T*> extend(std::vector<T>* vec, const T& val) {
    // The element is appended at the back, so its index equals the size the
    // vector had before the append.
    const auto newIndex = static_cast<uint32_t>(vec->size());
    vec->push_back(val);
    return std::make_pair(newIndex, &vec->back());
}
48 
// Appends a value-initialized element to `*vec`.
//
// Returns a pair holding the index of the new element and a pointer to it.
template <class T>
std::pair<uint32_t, T*> extend(std::vector<T>* vec) {
    // Delegate to the two-argument overload, handing it an explicitly
    // value-initialized T to copy.
    return extend(vec, T{});
}
55 
invalid(const Model & model,Version version,bool strictSlicing)56 bool invalid(const Model& model, Version version, bool strictSlicing) {
57     // A model must have at least one operation.  However, it's possible that a
58     // slice has no operations (because no operations from the original model
59     // are compliant with the sliced model type).  In this case, the sliced
60     // model would be invalid.
61     const bool looksEmpty = (model.main.operations.size() == 0);
62     if (strictSlicing) {
63         CHECK_EQ(looksEmpty, (model.main.operands.size() == 0));
64     }
65     if (looksEmpty) return true;
66 
67     // A model must have at least one output.  However, it's possible for a
68     // model to contain dead operations (i.e., outputs on which no model outputs
69     // are data dependent).  A slice might contain only dead operations, and
70     // hence have no model outputs.  In this case, the sliced model would be
71     // invalid.
72     if (model.main.outputIndexes.size() == 0) return true;
73 
74     // We shouldn't have to check whether the model is valid. However, it could
75     // be invalid if there is an error in the slicing algorithm.
76     auto maybeVersion = validate(model);
77     if (!maybeVersion.has_value()) {
78         LOG(WARNING) << "Sliced model fails validate(): " << maybeVersion.error();
79         CHECK(!strictSlicing);
80         return true;
81     }
82     if (!isCompliantVersion(maybeVersion.value(), version)) {
83         LOG(WARNING) << "Sliced model fails validate(): insufficient version ("
84                      << maybeVersion.value() << " vs " << version << ")";
85         CHECK(!strictSlicing);
86         return true;
87     }
88 
89     return false;
90 }
91 
92 }  // anonymous namespace
93 
MetaModel(Model model,bool strictSlicing)94 MetaModel::MetaModel(Model model, bool strictSlicing)
95     : mModel(std::move(model)),
96       mModelMinimumSupportedVersion(validate(mModel).value()),
97       mStrictSlicing(strictSlicing) {}
98 
getSlice(Version version) const99 MetaModel::ReturnedSlice MetaModel::getSlice(Version version) const {
100     // All slices of versions of at least mModelMinimumSupportedVersion are identical, so do not
101     // create more than one such slice.
102     version.level = std::min(version.level, mModelMinimumSupportedVersion.level);
103     version.runtimeOnlyFeatures &= mModelMinimumSupportedVersion.runtimeOnlyFeatures;
104 
105     auto& slice = mCachedSlices[version];
106     if (slice.mState == SliceState::UNINITIALIZED) {
107         slice = makeSlice(version);
108     }
109     if (slice.mState == SliceState::INVALID) {
110         return {};
111     }
112     return MetaModel::ReturnedSlice(std::make_pair(
113             slice.mModel, Mapper([&slice](uint32_t slicedOperationIndex) {
114                 return slice.mSlicedOperationIndexToOrigIndex.at(slicedOperationIndex);
115             })));
116 }
117 
118 // Utility class for makeSlice().
119 //
120 // For each output operand of a noncompliant operation that is the input
121 // operand of at least one compliant operation, we will ensure that there is
122 // a sliced model input whose "type" is that of the output operand.  This is
123 // a map from operand "type" (in the original model) to model input operand
124 // index (in the sliced model).  We only use the subset of the fields that are
125 // relevant (OperandType, dimensions, scale, zeroPoint, extraParams), but
126 // exclude irrelevant fields from the map key (lifetime, location).
127 //
128 // We also use this map for model input operands of the original model that
129 // become input operands of the sliced model.  This means that an original
130 // model input operand might be commoned with other original model input
131 // operands and/or with original model temporary operands.
132 class MetaModel::OrigOperandToSlicedInputOperandIndex {
133    public:
134     // `slicedOperands` and `slicedInputIndexes` will be modified as part of
135     // OrigOperandToSlicedInputOperandIndex::getIndex. `slicedVersion`, `operandValuesSize`, and
136     // `poolSizes` are used as a check to ensure that the sliced operand is valid and compliant with
137     // the sliced version. `operandValuesSize` is the size of the operand values in the sliced model
138     // (which is the same as the original model). `poolSizes` is the size of the memories in the
139     // sliced model (which is the same as the original model).
OrigOperandToSlicedInputOperandIndex(std::vector<Operand> * slicedOperands,std::vector<uint32_t> * slicedInputIndexes,Version slicedVersion,size_t operandValuesSize,std::vector<size_t> poolSizes)140     OrigOperandToSlicedInputOperandIndex(std::vector<Operand>* slicedOperands,
141                                          std::vector<uint32_t>* slicedInputIndexes,
142                                          Version slicedVersion, size_t operandValuesSize,
143                                          std::vector<size_t> poolSizes)
144         : mSlicedOperands(*slicedOperands),
145           mSlicedInputIndexes(*slicedInputIndexes),
146           kSlicedVersion(slicedVersion),
147           kOperandValuesSize(operandValuesSize),
148           kPoolSizes(std::move(poolSizes)) {}
149 
150     // Given an operand from the original model, return the index of the
151     // corresponding model input operand from the sliced model.  Creates a
152     // new operand in the sliced model if necessary.
getIndex(Operand operand)153     uint32_t getIndex(Operand operand) {
154         CHECK(operand.lifetime == Operand::LifeTime::SUBGRAPH_INPUT ||
155               operand.lifetime == Operand::LifeTime::SUBGRAPH_OUTPUT ||
156               operand.lifetime == Operand::LifeTime::TEMPORARY_VARIABLE);
157 
158         // Lookup
159         auto it = mMap.find(operand);
160         if (it != mMap.end()) {
161             VLOG(COMPILATION) << "OrigOperandToSlicedInputOperandIndex::getIndex looked for "
162                               << operand << " and found " << it->second << ": " << it->first;
163             return it->second;
164         }
165 
166         // Create
167         operand.lifetime = Operand::LifeTime::SUBGRAPH_INPUT;
168         operand.location = {};
169 
170         // Note that the sliced model does not contain any referenced subgraphs, so both `subgraphs`
171         // and `subgraphVersionCache` are empty.
172         const std::vector<Model::Subgraph> subgraphs;
173         auto subgraphVersionCache = createSubgraphVersionCache(subgraphs.size());
174         const auto minimumSupportedOperandVersion =
175                 validateOperandAndAnythingItDependsOn(operand, kOperandValuesSize, kPoolSizes,
176                                                       subgraphs, subgraphVersionCache.get())
177                         .value();
178         CHECK(isCompliantVersion(minimumSupportedOperandVersion, kSlicedVersion));
179 
180         uint32_t slicedOperandIndex = extend(&mSlicedOperands, operand).first;
181         mMap[operand] = slicedOperandIndex;
182         extend(&mSlicedInputIndexes, slicedOperandIndex);
183         VLOG(COMPILATION) << "OrigOperandToSlicedInputOperandIndex::getIndex created "
184                           << slicedOperandIndex << ": " << operand;
185         return slicedOperandIndex;
186     }
187 
188    private:
189     class Compare {
190        public:
operator ()(const Operand & a,const Operand & b) const191         bool operator()(const Operand& a, const Operand& b) const {
192             if (a.type != b.type) {
193                 return a.type < b.type;
194             }
195             if (a.dimensions != b.dimensions) {
196                 return a.dimensions < b.dimensions;
197             }
198             if (a.scale != b.scale) {
199                 return a.scale < b.scale;
200             }
201             if (a.zeroPoint != b.zeroPoint) {
202                 return a.zeroPoint < b.zeroPoint;
203             }
204             return compare(a.extraParams, b.extraParams);
205         }
206 
207        private:
compare(const Operand::SymmPerChannelQuantParams & a,const Operand::SymmPerChannelQuantParams & b)208         static bool compare(const Operand::SymmPerChannelQuantParams& a,
209                             const Operand::SymmPerChannelQuantParams& b) {
210             if (a.scales != b.scales) {
211                 return a.scales < b.scales;
212             }
213             return a.channelDim < b.channelDim;
214         }
compare(const Operand::ExtraParams & a,const Operand::ExtraParams & b)215         static bool compare(const Operand::ExtraParams& a, const Operand::ExtraParams& b) {
216             if (a.index() != b.index()) {
217                 return a.index() < b.index();
218             }
219             if (std::holds_alternative<Operand::SymmPerChannelQuantParams>(a)) {
220                 return compare(std::get<Operand::SymmPerChannelQuantParams>(a),
221                                std::get<Operand::SymmPerChannelQuantParams>(b));
222             }
223             if (std::holds_alternative<Operand::ExtensionParams>(a)) {
224                 return std::get<Operand::ExtensionParams>(a) <
225                        std::get<Operand::ExtensionParams>(b);
226             }
227             if (std::holds_alternative<Operand::NoParams>(a)) {
228                 return false;
229             }
230             CHECK(false) << "Unexpected";
231             return false;
232         }
233     };
234     std::map<Operand, uint32_t, Compare> mMap;
235     std::vector<Operand>& mSlicedOperands;
236     std::vector<uint32_t>& mSlicedInputIndexes;
237     const Version kSlicedVersion;
238     const size_t kOperandValuesSize;
239     const std::vector<size_t> kPoolSizes;
240 };
241 
processOperations(Slice * slice,std::map<uint32_t,uint32_t> * origOperandIndexToSlicedIndex,OrigOperandToSlicedInputOperandIndex * origOperandToSlicedInputOperandIndex,const std::set<uint32_t> & noncompliantOperations,const std::set<uint32_t> & inputOperandIndexesOfCompliantOperations) const242 void MetaModel::processOperations(
243         Slice* slice, std::map<uint32_t, uint32_t>* origOperandIndexToSlicedIndex,
244         OrigOperandToSlicedInputOperandIndex* origOperandToSlicedInputOperandIndex,
245         const std::set<uint32_t>& noncompliantOperations,
246         const std::set<uint32_t>& inputOperandIndexesOfCompliantOperations) const {
247     const auto& origOperands = mModel.main.operands;
248     const auto& origOperations = mModel.main.operations;
249     auto& slicedOperands = slice->mModel.main.operands;
250     auto& slicedOperations = slice->mModel.main.operations;
251 
252     std::vector<uint32_t> origOperandNumberOfConsumers =
253             countNumberOfConsumers(origOperands.size(), origOperations).value();
254 
255     for (uint32_t origOperationIndex = 0; origOperationIndex < origOperations.size();
256          ++origOperationIndex) {
257         const Operation& origOperation = origOperations[origOperationIndex];
258 
259         if (noncompliantOperations.count(origOperationIndex)) {
260             for (uint32_t output : origOperation.outputs) {
261                 if (!inputOperandIndexesOfCompliantOperations.count(output)) {
262                     continue;
263                 }
264                 const uint32_t slicedIndex =
265                         origOperandToSlicedInputOperandIndex->getIndex(origOperands[output]);
266                 (*origOperandIndexToSlicedIndex)[output] = slicedIndex;
267                 VLOG(COMPILATION)
268                         << "origOperandIndexToSlicedIndex noncompliant output processing created "
269                         << output << " -> " << slicedIndex << ": " << slicedOperands[slicedIndex];
270             }
271         } else {
272             slice->mSlicedOperationIndexToOrigIndex.push_back(origOperationIndex);
273             Operation& slicedOperation = *extend(&slicedOperations).second;
274             CHECK_EQ(slice->mSlicedOperationIndexToOrigIndex.size(), slicedOperations.size());
275 
276             slicedOperation.type = origOperation.type;
277 
278             // Model is topologically sorted, so all operation inputs must be
279             // present in origOperandIndexToSlicedIndex, and no operation
280             // outputs may be.
281 
282             // Operation inputs
283             // - Fill in slicedOperation.inputs
284             slicedOperation.inputs.resize(origOperation.inputs.size());
285             std::transform(
286                     origOperation.inputs.begin(), origOperation.inputs.end(),
287                     slicedOperation.inputs.begin(),
288                     [&origOperandIndexToSlicedIndex, &slicedOperands](uint32_t origOperandIndex) {
289                         uint32_t slicedOperandIndex =
290                                 origOperandIndexToSlicedIndex->at(origOperandIndex);
291                         VLOG(COMPILATION) << "origOperandIndexToSlicedIndex compliant input "
292                                              "processing created "
293                                           << origOperandIndex << " -> " << slicedOperandIndex
294                                           << ": " << slicedOperands[slicedOperandIndex];
295                         return slicedOperandIndex;
296                     });
297 
298             // Operation outputs
299             // - Add new operands to slicedOperands
300             // - Update origOperandIndexToSlicedIndex
301             // - Fill in slicedOperation.outputs
302             // - Record as a model output, if necessary
303             const uint32_t firstOutputSlicedOperandIndex = slicedOperands.size();
304             slicedOperands.resize(firstOutputSlicedOperandIndex + origOperation.outputs.size());
305             slicedOperation.outputs.resize(origOperation.outputs.size());
306             for (uint32_t outputNum = 0; outputNum < slicedOperation.outputs.size(); ++outputNum) {
307                 uint32_t origOperandIndex = origOperation.outputs[outputNum];
308                 uint32_t slicedOperandIndex = firstOutputSlicedOperandIndex + outputNum;
309                 auto& slicedOperand = slicedOperands[slicedOperandIndex];
310                 const auto& origOperand = origOperands[origOperandIndex];
311                 slicedOperand = origOperand;
312 
313                 CHECK_EQ(origOperandIndexToSlicedIndex->count(origOperandIndex), size_t(0));
314                 (*origOperandIndexToSlicedIndex)[origOperandIndex] = slicedOperandIndex;
315                 slicedOperation.outputs[outputNum] = slicedOperandIndex;
316 
317                 const auto subgraphOutputLifetime = Operand::LifeTime::SUBGRAPH_OUTPUT;
318                 if (!inputOperandIndexesOfCompliantOperations.count(origOperandIndex) &&
319                     origOperandNumberOfConsumers[origOperandIndex] != 0) {
320                     // Was consumed only by noncompliant operations; convert to
321                     // an output of the sliced model.
322                     slicedOperand.lifetime = subgraphOutputLifetime;
323                 }
324 
325                 VLOG(COMPILATION) << "origOperandIndexToSlicedIndex compliant output created "
326                                   << origOperandIndex << " -> " << slicedOperandIndex << ": "
327                                   << slicedOperand;
328 
329                 if (slicedOperand.lifetime == subgraphOutputLifetime) {
330                     extend(&slice->mModel.main.outputIndexes, slicedOperandIndex);
331                 }
332             }
333         }
334     }
335 }
336 
getNoncompliantOperations(Version version) const337 std::set<uint32_t> MetaModel::getNoncompliantOperations(Version version) const {
338     const auto [operandValuesSize, poolSizes] = getMemorySizes(mModel);
339 
340     auto subgraphVersionCache = createSubgraphVersionCache(mModel.referenced.size());
341     std::set<uint32_t> noncompliantOperations;
342     for (uint32_t i = 0; i < mModel.main.operations.size(); ++i) {
343         const auto& operation = mModel.main.operations[i];
344         const auto minSupportedVersion =
345                 validateOperationAndAnythingItDependsOn(
346                         operation, mModel.main.operands, operandValuesSize, poolSizes,
347                         mModel.referenced, subgraphVersionCache.get())
348                         .value();
349         if (!isCompliantVersion(minSupportedVersion, version)) {
350             noncompliantOperations.insert(i);
351         }
352     }
353     return noncompliantOperations;
354 }
355 
operator ()(Version lhs,Version rhs) const356 bool MetaModel::Comparison::operator()(Version lhs, Version rhs) const {
357     constexpr auto toTuple = [](const Version& v) {
358         return std::tie(v.level, v.runtimeOnlyFeatures);
359     };
360     // Lexicographical comparison of the fields. The bool is promoted to an integer for the
361     // comparison such that "false < true".
362     return toTuple(lhs) < toTuple(rhs);
363 }
364 
makeSlice(Version version) const365 MetaModel::Slice MetaModel::makeSlice(Version version) const {
366     Slice slice;
367 
368     // Quickly return if the model is already compliant with `version`
369     if (isCompliantVersion(mModelMinimumSupportedVersion, version)) {
370         slice.mModel = mModel;
371         slice.mSlicedOperationIndexToOrigIndex =
372                 std::vector<uint32_t>(mModel.main.operations.size());
373         std::iota(slice.mSlicedOperationIndexToOrigIndex.begin(),
374                   slice.mSlicedOperationIndexToOrigIndex.end(), 0u);
375         slice.mState = SliceState::NORMAL;
376         return slice;
377     }
378 
379     const auto& origOperands = mModel.main.operands;
380     const auto& origOperations = mModel.main.operations;
381     auto& slicedOperands = slice.mModel.main.operands;
382 
383     // Indexes of elements of noncompliant origOperations
384     std::set<uint32_t> noncompliantOperations = getNoncompliantOperations(version);
385 
386     // Check if any compliant operations require a subgraph.
387     bool someCompliantOperationHasASubgraphOperand = false;
388     if (!mModel.referenced.empty()) {
389         for (size_t i = 0; i < mModel.main.operations.size(); ++i) {
390             const auto& operation = mModel.main.operations[i];
391             if (noncompliantOperations.count(i) > 0) {
392                 continue;
393             }
394             const auto isSubgraph = [&origOperands](uint32_t opndIdx) {
395                 return origOperands[opndIdx].lifetime == Operand::LifeTime::SUBGRAPH;
396             };
397             if (std::any_of(operation.inputs.begin(), operation.inputs.end(), isSubgraph)) {
398                 someCompliantOperationHasASubgraphOperand = true;
399                 break;
400             }
401         }
402     }
403 
404     // TODO(b/175418767): Currently, MetaModel is not equipped to slice referenced subgraphs. If the
405     // original model is not compliant with the specified version and contains referenced subgraphs
406     // needed by the slice, return an invalidated slice.
407     if (someCompliantOperationHasASubgraphOperand) {
408         slice.mState = SliceState::INVALID;
409         return slice;
410     }
411 
412     // Map from an operand index in origOperands to the corresponding operand index in
413     // slicedOperands
414     std::map<uint32_t, uint32_t> origOperandIndexToSlicedIndex;
415 
416     // Collect the operand indexes of every operand that is an input to a
417     // compliant operation.  If the operand is a CONSTANT_*, POINTER, or a
418     // NO_VALUE, copy it to the sliced model and update
419     // origOperandIndexToSlicedIndex accordingly.  Otherwise, we'll deal with
420     // the operand in the subsequent "Main loop", where we process operation
421     // outputs (intermediates and model outputs).
422     std::set<uint32_t> inputOperandIndexesOfCompliantOperations;
423     for (uint32_t origOperationIndex = 0; origOperationIndex < origOperations.size();
424          ++origOperationIndex) {
425         if (noncompliantOperations.count(origOperationIndex)) {
426             continue;
427         }
428         for (uint32_t input : origOperations[origOperationIndex].inputs) {
429             if (inputOperandIndexesOfCompliantOperations.insert(input).second) {
430                 const Operand& origOperand = origOperands[input];
431                 switch (origOperand.lifetime) {
432                     case Operand::LifeTime::CONSTANT_COPY:
433                     case Operand::LifeTime::CONSTANT_REFERENCE:
434                     case Operand::LifeTime::POINTER:
435                     case Operand::LifeTime::NO_VALUE: {
436                         const uint32_t slicedOperandIndex =
437                                 extend(&slicedOperands, origOperand).first;
438                         origOperandIndexToSlicedIndex[input] = slicedOperandIndex;
439                         VLOG(COMPILATION) << "origOperandIndexToSlicedIndex initialization created "
440                                           << input << " -> " << slicedOperandIndex << ": "
441                                           << slicedOperands[slicedOperandIndex];
442                         break;
443                     }
444                     default:
445                         break;
446                 }
447             }
448         }
449     }
450 
451     const auto [operandValuesSize, poolSizes] = getMemorySizes(mModel);
452 
453     OrigOperandToSlicedInputOperandIndex origOperandToSlicedInputOperandIndex(
454             &slicedOperands, &slice.mModel.main.inputIndexes, version, operandValuesSize,
455             poolSizes);
456 
457     // An input of the original model is an input of the sliced model if and
458     // only if it is consumed by at least one compliant operation.  Note that in
459     // the sliced model we share all model inputs of the same "type"; and that
460     // we may later add model inputs to the sliced model.
461     for (uint32_t origInputIndex : mModel.main.inputIndexes) {
462         if (inputOperandIndexesOfCompliantOperations.count(origInputIndex)) {
463             const uint32_t slicedIndex =
464                     origOperandToSlicedInputOperandIndex.getIndex(origOperands[origInputIndex]);
465             origOperandIndexToSlicedIndex[origInputIndex] = slicedIndex;
466             VLOG(COMPILATION) << "origOperandIndexToSlicedIndex inputIndexes processing created "
467                               << origInputIndex << " -> " << slicedIndex << ": "
468                               << slicedOperands[slicedIndex];
469         }
470     }
471 
472     // Main loop: Process each operation of the original model.
473     processOperations(&slice, &origOperandIndexToSlicedIndex, &origOperandToSlicedInputOperandIndex,
474                       noncompliantOperations, inputOperandIndexesOfCompliantOperations);
475 
476     // To keep things simple, we copy over these fields as-is.  We could instead
477     // opt to regenerate them based on the operands present in the sliced model:
478     // This would be more complex and probably take more computation time, but
479     // it would reduce the size of the sliced model, and hence the time spent
480     // copying it around and potentially passing it across process boundaries.
481     slice.mModel.operandValues = mModel.operandValues;
482     slice.mModel.pools = mModel.pools;
483 
484     if (VLOG_IS_ON(COMPILATION)) {
485         {
486             std::ostringstream fromName;
487             fromName << "Slice: From canonical";
488             graphDump(fromName.str().c_str(), mModel);
489         }
490         {
491             std::ostringstream toName;
492             toName << "Slice: To " << version;
493             graphDump(toName.str().c_str(), slice.mModel);
494         }
495     }
496 
497     slice.mState = invalid(slice.mModel, version, mStrictSlicing) ? SliceState::INVALID
498                                                                   : SliceState::NORMAL;
499 
500     return slice;
501 }
502 
503 }  // namespace android::nn
504