1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_META_MODEL_H
18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_META_MODEL_H
19 
20 #include <android-base/macros.h>
21 
22 #include <functional>
23 #include <map>
24 #include <optional>
25 #include <set>
26 #include <utility>
27 #include <vector>
28 
29 #include "nnapi/Types.h"
30 
31 namespace android::nn {
32 
33 // The MetaModel class encapsulates a Model and provides machinery to create
34 // from that original Model a "slice" of that Model consisting of:
35 // - the subset of operations that is compliant with a particular version; and
36 // - a mechanism for mapping operations from the slice back to operations of the
37 //   original Model.
38 // The slice is intended to be passed to IDevice::getSupportedOperations(),
39 // with the mapping used to translate the results of that call from the slice's
40 // operations to the original Model's operations.  The slice has no other
41 // purpose (for example, it is not guaranteed to have the same topology as a
42 // subgraph of the original model).
43 //
44 // When getSlice() is called, a slice is created and cached, if necessary; and
45 // then the cached slice is returned.
46 //
47 // The meaning of the return value of getSlice() is explained by the following
48 // example:
49 //
50 //     const MetaModel& metaModel = ...;
51 //     auto ret = metaModel.getSlice(kVersionFeatureLevel1);
52 //     if (ret.has_value()) {
53 //         const Model model = ret->first;  // the slice
54 //         auto mapper = ret->second;
55 //         // mapper is a functor that takes an operation index in the
56 //         // slice and returns the corresponding operation index in the
57 //         // original Model.  The functor will remain valid for the lifetime
58 //         // of the MetaModel.
59 //     } else {
60 //         // Could not obtain a slice.  For example, perhaps none of the
61 //         // original model's operations are compliant with
62 //         // kVersionFeatureLevel1.
63 //     }
64 //
65 class MetaModel {
66    public:
67     using Mapper = std::function<uint32_t(uint32_t)>;
68 
69     using ReturnedSlice = std::optional<std::pair<Model, Mapper>>;
70 
71     // Precondition: validate(model).has_value()
72     MetaModel(Model model, bool strictSlicing);
73 
getModel()74     const Model& getModel() const { return mModel; }
75 
76     ReturnedSlice getSlice(Version version) const;
77 
78     // Disallowing copy constructor and assignment operator is for efficiency,
79     // not for correctness.  The default copy constructor and assignment
80     // operator would work fine.  However, they could be surprisingly expensive
81     // if the mCachedSlices member gets copied: Up to one Model instance and
82     // one std::vector instance per version could be copied.  We could choose
83     // to accept this expense; or we could write custom copy and assign that do
84     // not copy the mCachedSlices member but instead set the destination
85     // mCachedSlices Slice::mState members to SliceState::UNINITIALIZED.
86     //
87     // There are no such issues with move constructor and move assignment.
88     MetaModel(const MetaModel&) = delete;
89     MetaModel& operator=(const MetaModel&) = delete;
90     MetaModel(MetaModel&&) = default;
91     MetaModel& operator=(MetaModel&&) = default;
92 
93    private:
94     Model mModel;
95     Version mModelMinimumSupportedVersion;
96 
97     // mStrictSlicing controls validity checking.  If the slicing algorithm
98     // produces an invalid model (because something has gone wrong with the
99     // algorithm or with a utility function it depends on), getSlice() can
100     // return an std::optional<> for which has_value() returns false, signifying
101     // that no slice is available.  However, if mStrictSlicing is true,
102     // getSlice() cause a CHECK*() to fail.  This can be used in debugging to
103     // find situations where slicing has failed unexpectedly.
104     bool mStrictSlicing;
105 
106     enum class SliceState { UNINITIALIZED, INVALID, NORMAL };
107     struct Slice {
108         SliceState mState = SliceState::UNINITIALIZED;
109         Model mModel;
110         std::vector<uint32_t> mSlicedOperationIndexToOrigIndex;
111     };
112 
113     struct Comparison {
114         bool operator()(Version lhs, Version rhs) const;
115     };
116     mutable std::map<Version, Slice, Comparison> mCachedSlices;
117 
118     Slice makeSlice(Version version) const;
119 
120     std::set<uint32_t> getNoncompliantOperations(Version version) const;
121 
122     // Utility class for makeSlice().
123     class OrigOperandToSlicedInputOperandIndex;
124 
125     // Utility function for makeSlice(): Walks operations of original
126     // model and populates sliced model accordingly.
127     void processOperations(
128             Slice* slice, std::map<uint32_t, uint32_t>* origOperandIndexToSlicedIndex,
129             OrigOperandToSlicedInputOperandIndex* origOperandToSlicedInputOperandIndex,
130             const std::set<uint32_t>& noncompliantOperations,
131             const std::set<uint32_t>& inputOperandIndexesOfCompliantOperations) const;
132 };
133 
134 }  // namespace android::nn
135 
136 #endif  // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_META_MODEL_H
137