1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_TYPE_MANAGER_H
18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_TYPE_MANAGER_H
19 
20 #include <map>
21 #include <set>
22 #include <string>
23 #include <vector>
24 
25 #include "Manager.h"
26 
27 #ifndef NN_COMPATIBILITY_LIBRARY_BUILD
28 #include "AppInfoFetcher.h"
29 #endif  // NN_COMPATIBILITY_LIBRARY_BUILD
30 
31 namespace android {
32 namespace nn {
33 
34 // Manages runtime operand and operation type information.
35 //
36 // This class gathers information about extension types from all devices
37 // and provides a unified way to access information about any known type.
38 class TypeManager {
39    public:
get()40     static TypeManager* get() {
41         static TypeManager manager;
42         return &manager;
43     }
44 
45     // Creates an operand/operation type corresponding to a given extension
46     // name and type within extension.
47     //
48     // Returns false if the extension is unknown.
49     bool getExtensionType(const char* extensionName, uint16_t typeWithinExtension, int32_t* type);
50 
51     // Looks up information about the extension corresponding to the given prefix
52     //
53     // Returns false if no extension corresponds to the given prefix.
54     bool getExtensionInfo(uint16_t prefix, const Extension** extension) const;
55 
56     // Looks up information about an extension operand type
57     //
58     // Returns false if the extension or type is unknown.
59     bool getExtensionOperandTypeInfo(OperandType type,
60                                      const Extension::OperandTypeInformation** info) const;
61 
62     // Returns true if an operand type is a tensor type.
63     //
64     // Aborts if the type is an unknown extension type.
65     bool isTensorType(OperandType type) const;
66 
67     // Returns the amount of space needed to store a value of the dimensions and
68     // type of this operand. For a tensor with unspecified rank or at least one
69     // unspecified dimension, returns zero.
70     //
71     // Aborts if the type is an unknown extension type.
72     // Aborts if the size would overflow the return type.
getSizeOfData(const Operand & operand)73     uint32_t getSizeOfData(const Operand& operand) const {
74         return getSizeOfData(operand.type, operand.dimensions);
75     }
76 
77     // Returns the amount of space needed to store a value of the specified
78     // dimensions and type. For a tensor with unspecified rank or at least one
79     // unspecified dimension, returns zero.
80     //
81     // Aborts if the type is an unknown extension type.
82     uint32_t getSizeOfData(OperandType type, const std::vector<uint32_t>& dimensions) const;
83 
84     // Returns the ExtensionNameAndPrefix mapping from metaData.
85     std::vector<ExtensionNameAndPrefix> getExtensionNameAndPrefix(
86             const std::vector<TokenValuePair>& metaData);
87 
88     // Returns true if the amount of space needed to store a value of the specified
89     // dimensions and element size overflows the uint32_t type.
90     //
91     // See also TypeManager::sizeOfDataOverflowsUInt32().
92     bool sizeOfDataOverflowsUInt32(OperandType type, const std::vector<uint32_t>& dimensions) const;
93 
94     // Returns true if extensions usage is allowed in current process.
areExtensionsAllowed()95     bool areExtensionsAllowed() const { return mExtensionsAllowed; }
96 
97     // This method is intended for use only by internal unit tests.
98     //
99     // Registers an extension.
100     //
101     // Returns true if the registration was successful.
forTest_registerExtension(const Extension & extension)102     bool forTest_registerExtension(const Extension& extension) {
103         return registerExtension(extension, "INTERNAL TEST");
104     }
105 
106     // This method is intended for use only by internal unit tests.
107     //
108     // Resets the internal state.
109     //
110     // After calling forTest_registerExtension() any number of times, call
111     // forTest_reset() to return to the state as if forTest_registerExtension()
112     // had never been called. Note that forTest_reset() resets all internal
113     // state (including assigned prefixes) and re-discovers extensions from
114     // available devices.
forTest_reset()115     void forTest_reset() { *this = TypeManager(); }
116 
117 #ifndef NN_COMPATIBILITY_LIBRARY_BUILD
118     // Check if NNAPI Vendor extensions are usable in the process with the given app
119     // and supplemental infomation.
120     //
121     // useOnProductImageEnabled - whether apps/binaries preinstalled on /product partition
122     // can be enabled for extensions use.
123     // allowlist - list of apps/binaries which are allowed to use extensions.
124     static bool isExtensionsUseAllowed(const AppInfoFetcher::AppInfo& appPackageInfo,
125                                        bool useOnProductImageEnabled,
126                                        const std::vector<std::string>& allowlist);
127 #endif  // NN_COMPATIBILITY_LIBRARY_BUILD
128 
129    private:
130     TypeManager();
131     void findAvailableExtensions();
132     bool registerExtension(Extension extension, const std::string& deviceName);
133 
134     // Returns the numeric "prefix" value corresponding to an extension.
135     //
136     // Returns false when assigning a new prefix would overflow uint16_t.
137     bool getExtensionPrefix(const std::string& extensionName, uint16_t* prefix);
138 
139     const DeviceManager* mDeviceManager = DeviceManager::get();
140 
141     // Contains all registered extensions.
142     std::map<std::string, Extension> mExtensionNameToExtension;
143 
144     // Contains the name of the first discovered device that supports an
145     // extension. Used for error reporting.
146     std::map<std::string, std::string> mExtensionNameToFirstDevice;
147 
148     // When multiple devices report conflicting information about an extension,
149     // the extension is disabled.
150     std::set<std::string> mDisabledExtensions;
151 
152     // The fields below are used to support efficient extension name to
153     // prefix mapping. New prefixes are created by getExtensionPrefix.
154     std::map<std::string, uint16_t> mExtensionNameToPrefix;
155     // Entries of mPrefixToExtension point into mExtensionNameToExtension.
156     // prefix=0 corresponds to no extension and should never be looked up.
157     std::vector<Extension*> mPrefixToExtension = {nullptr};
158 
159     // True if Extensions can be used in current process.
160     bool mExtensionsAllowed = false;
161 };
162 
163 }  // namespace nn
164 }  // namespace android
165 
166 #endif  // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_RUNTIME_TYPE_MANAGER_H
167