/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_META_MODEL_H
#define ANDROID_FRAMEWORKS_ML_NN_COMMON_META_MODEL_H

#include "HalInterfaces.h"

#include <android-base/macros.h>
#include <functional>
#include <map>
#include <optional>
#include <set>
#include <utility>
#include <vector>

namespace android::nn {

// The MetaModel class encapsulates a Model and provides machinery to create
// from that original Model a "slice" of that Model consisting of:
// - the subset of operations that is compliant with a particular HAL version; and
// - a mechanism for mapping operations from the slice back to operations of the
//   original Model.
// The slice is intended to be passed to IDevice::getSupportedOperations*(),
// with the mapping used to translate the results of that call from the slice's
// operations to the original Model's operations.  The slice has no other
// purpose (for example, it is not guaranteed to have the same topology as a
// subgraph of the original model).
//
// When a getSlice*() method is called, a slice is created and cached, if
// necessary; and then the cached slice is returned.
45 // 46 // The meaning of the return value of the getSlice*() methods is explained by 47 // the following example: 48 // 49 // const MetaModel& metaModel = ...; 50 // auto ret = metaModel.getSliceV1_0(); // getSliceV1_1() is similar 51 // if (ret.has_value()) { 52 // const V1_0::Model model = ret->first; // the slice 53 // auto mapper = ret->second; 54 // // mapper is a functor that takes an operation index in the 55 // // slice and returns the corresponding operation index in the 56 // // original Model. The functor will remain valid for the lifetime 57 // // of the MetaModel. 58 // } else { 59 // // Could not obtain a slice. For example, perhaps none of the 60 // // original model's operations are compliant with V1_0. 61 // } 62 // 63 class MetaModel { 64 public: 65 using Mapper = std::function<uint32_t(uint32_t)>; 66 67 template <class T_Model> 68 using ReturnedSlice = std::optional<std::pair<T_Model, Mapper>>; 69 MetaModel(hal::Model model,bool strictSlicing)70 MetaModel(hal::Model model, bool strictSlicing) 71 : mHidlModel(std::move(model)), mStrictSlicing(strictSlicing) {} 72 getModel()73 const hal::Model& getModel() const { return mHidlModel; } 74 getSliceV1_0()75 ReturnedSlice<hal::V1_0::Model> getSliceV1_0() const { return getSlice(&mSliceV1_0); } getSliceV1_1()76 ReturnedSlice<hal::V1_1::Model> getSliceV1_1() const { return getSlice(&mSliceV1_1); } getSliceV1_2()77 ReturnedSlice<hal::V1_2::Model> getSliceV1_2() const { return getSlice(&mSliceV1_2); } 78 79 // Disallowing copy constructor and assignment operator is for efficiency, 80 // not for correctness. The default copy constructor and assignment 81 // operator would work fine. However, they could be surprisingly expensive 82 // if the mSlice* members get copied: Up to three Model instances and two 83 // std::vector instances could be copied. 
We could choose to accept this 84 // expense; or we could write custom copy and assign that do not copy the 85 // mSlice* members but instead set the destination mSlice* members to 86 // SliceState::UNINITIALIZED. 87 // 88 // There are no such issues with move constructor and move assignment. 89 MetaModel(const MetaModel&) = delete; 90 MetaModel& operator=(const MetaModel&) = delete; 91 MetaModel(MetaModel&&) = default; 92 MetaModel& operator=(MetaModel&&) = default; 93 94 private: 95 hal::Model mHidlModel; 96 97 // mStrictSlicing controls sanity checking. If the slicing algorithm 98 // produces an invalid model (because something has gone wrong with the 99 // algorithm or with a utility function it depends on), getSlice*() can 100 // return an std::optional<> for which has_value() returns false, signifying 101 // that no slice is available. However, if mStrictSlicing is true, 102 // getSlice*() cause a CHECK*() to fail. This can be used in debugging to 103 // find situations where slicing has failed unexpectedly. 104 bool mStrictSlicing; 105 106 enum class SliceState { UNINITIALIZED, INVALID, NORMAL }; 107 template <class T_SlicedModel> 108 struct Slice { 109 SliceState mState = SliceState::UNINITIALIZED; 110 T_SlicedModel mHidlModel; 111 std::vector<uint32_t> mSlicedOperationIndexToOrigIndex; 112 113 using Operand = typename decltype(mHidlModel.operands)::value_type; 114 using Operation = typename decltype(mHidlModel.operations)::value_type; 115 using OperationType = decltype(Operation::type); 116 }; 117 mutable Slice<hal::V1_0::Model> mSliceV1_0; 118 mutable Slice<hal::V1_1::Model> mSliceV1_1; 119 mutable Slice<hal::V1_2::Model> mSliceV1_2; 120 121 template <class T_SlicedModel> 122 ReturnedSlice<T_SlicedModel> getSlice(Slice<T_SlicedModel>* slice) const; 123 124 template <class T_SlicedModel> 125 Slice<T_SlicedModel> makeSlice() const; 126 127 // Utility class for makeSlice(). 
128 template <typename T_SlicedOperand> 129 class OrigOperandToSlicedInputOperandIndex; 130 131 // Utility function for makeSlice(): Walks operations of original 132 // model and populates sliced model accordingly. 133 template <class T_SlicedModel> 134 void processOperations( 135 Slice<T_SlicedModel>* slice, 136 std::map<uint32_t, uint32_t>* origOperandIndexToSlicedIndex, 137 OrigOperandToSlicedInputOperandIndex<typename Slice<T_SlicedModel>::Operand>* 138 origOperandToSlicedInputOperandIndex, 139 const std::set<uint32_t>& noncompliantOperations, 140 const std::set<uint32_t>& inputOperandIndexesOfCompliantOperations) const; 141 }; 142 143 } // namespace android::nn 144 145 #endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_META_MODEL_H 146