1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_META_MODEL_H 18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_META_MODEL_H 19 20 #include <android-base/macros.h> 21 22 #include <functional> 23 #include <map> 24 #include <optional> 25 #include <set> 26 #include <utility> 27 #include <vector> 28 29 #include "nnapi/Types.h" 30 31 namespace android::nn { 32 33 // The MetaModel class encapsulates a Model and provides machinery to create 34 // from that original Model a "slice" of that Model consisting of: 35 // - the subset of operations that is compliant with a particular version; and 36 // - a mechanism for mapping operations from the slice back to operations of the 37 // original Model. 38 // The slice is intended to be passed to IDevice::getSupportedOperations(), 39 // with the mapping used to translate the results of that call from the slice's 40 // operations to the original Model's operations. The slice has no other 41 // purpose (for example, it is not guaranteed to have the same topology as a 42 // subgraph of the original model). 43 // 44 // When getSlice() is called, a slice is created and cached, if necessary; and 45 // then the cached slice is returned. 46 // 47 // The meaning of the return value of getSlice() is explained by the following 48 // example: 49 // 50 // const MetaModel& metaModel = ...; 51 // auto ret = metaModel.getSlice(kVersionFeatureLevel1); 52 // if (ret.has_value()) { 53 // const Model model = ret->first; // the slice 54 // auto mapper = ret->second; 55 // // mapper is a functor that takes an operation index in the 56 // // slice and returns the corresponding operation index in the 57 // // original Model. The functor will remain valid for the lifetime 58 // // of the MetaModel. 59 // } else { 60 // // Could not obtain a slice. For example, perhaps none of the 61 // // original model's operations are compliant with 62 // // kVersionFeatureLevel1. 63 // } 64 // 65 class MetaModel { 66 public: 67 using Mapper = std::function<uint32_t(uint32_t)>; 68 69 using ReturnedSlice = std::optional<std::pair<Model, Mapper>>; 70 71 // Precondition: validate(model).has_value() 72 MetaModel(Model model, bool strictSlicing); 73 getModel()74 const Model& getModel() const { return mModel; } 75 76 ReturnedSlice getSlice(Version version) const; 77 78 // Disallowing copy constructor and assignment operator is for efficiency, 79 // not for correctness. The default copy constructor and assignment 80 // operator would work fine. However, they could be surprisingly expensive 81 // if the mCachedSlices member gets copied: Up to one Model instance and 82 // one std::vector instance per version could be copied. We could choose 83 // to accept this expense; or we could write custom copy and assign that do 84 // not copy the mCachedSlices member but instead set the destination 85 // mCachedSlices Slice::mState members to SliceState::UNINITIALIZED. 86 // 87 // There are no such issues with move constructor and move assignment. 88 MetaModel(const MetaModel&) = delete; 89 MetaModel& operator=(const MetaModel&) = delete; 90 MetaModel(MetaModel&&) = default; 91 MetaModel& operator=(MetaModel&&) = default; 92 93 private: 94 Model mModel; 95 Version mModelMinimumSupportedVersion; 96 97 // mStrictSlicing controls validity checking. If the slicing algorithm 98 // produces an invalid model (because something has gone wrong with the 99 // algorithm or with a utility function it depends on), getSlice() can 100 // return an std::optional<> for which has_value() returns false, signifying 101 // that no slice is available. However, if mStrictSlicing is true, 102 // getSlice() cause a CHECK*() to fail. This can be used in debugging to 103 // find situations where slicing has failed unexpectedly. 104 bool mStrictSlicing; 105 106 enum class SliceState { UNINITIALIZED, INVALID, NORMAL }; 107 struct Slice { 108 SliceState mState = SliceState::UNINITIALIZED; 109 Model mModel; 110 std::vector<uint32_t> mSlicedOperationIndexToOrigIndex; 111 }; 112 113 struct Comparison { 114 bool operator()(Version lhs, Version rhs) const; 115 }; 116 mutable std::map<Version, Slice, Comparison> mCachedSlices; 117 118 Slice makeSlice(Version version) const; 119 120 std::set<uint32_t> getNoncompliantOperations(Version version) const; 121 122 // Utility class for makeSlice(). 123 class OrigOperandToSlicedInputOperandIndex; 124 125 // Utility function for makeSlice(): Walks operations of original 126 // model and populates sliced model accordingly. 127 void processOperations( 128 Slice* slice, std::map<uint32_t, uint32_t>* origOperandIndexToSlicedIndex, 129 OrigOperandToSlicedInputOperandIndex* origOperandToSlicedInputOperandIndex, 130 const std::set<uint32_t>& noncompliantOperations, 131 const std::set<uint32_t>& inputOperandIndexesOfCompliantOperations) const; 132 }; 133 134 } // namespace android::nn 135 136 #endif // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_META_MODEL_H 137