1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_ML_NN_RUNTIME_EXTENSION_MANAGER_H
18 #define ANDROID_ML_NN_RUNTIME_EXTENSION_MANAGER_H
19 
#include "HalInterfaces.h"
#include "Manager.h"

#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <vector>
26 
27 namespace android {
28 namespace nn {
29 
30 // Manages runtime operand and operation type information.
31 //
32 // This class gathers information about extension types from all devices
33 // and provides a unified way to access information about any known type.
34 class TypeManager {
35    public:
get()36     static TypeManager* get() {
37         static TypeManager manager;
38         return &manager;
39     }
40 
41     // Creates an operand/operation type corresponding to a given extension
42     // name and type within extension.
43     //
44     // Returns false if the extension is unknown.
45     bool getExtensionType(const char* extensionName, uint16_t typeWithinExtension, int32_t* type);
46 
47     // Looks up information about the extension corresponding to the given prefix
48     //
49     // Returns false if no extension corresponds to the given prefix.
50     bool getExtensionInfo(uint16_t prefix, const Extension** extension) const;
51 
52     // Looks up information about an extension operand type
53     //
54     // Returns false if the extension or type is unknown.
55     bool getExtensionOperandTypeInfo(OperandType type,
56                                      const Extension::OperandTypeInformation** info) const;
57 
58     // Returns true if an operand type is a tensor type.
59     //
60     // Aborts if the type is an unknown extension type.
61     bool isTensorType(OperandType type) const;
62 
63     // Returns the amount of space needed to store a value of the dimensions and
64     // type of this operand. For a tensor with unspecified rank or at least one
65     // unspecified dimension, returns zero.
66     //
67     // Aborts if the type is an unknown extension type.
getSizeOfData(const Operand & operand)68     uint32_t getSizeOfData(const Operand& operand) const {
69         return getSizeOfData(operand.type, operand.dimensions);
70     }
71 
72     // Returns the amount of space needed to store a value of the specified
73     // dimensions and type. For a tensor with unspecified rank or at least one
74     // unspecified dimension, returns zero.
75     //
76     // Aborts if the type is an unknown extension type.
77     uint32_t getSizeOfData(OperandType type, const std::vector<uint32_t>& dimensions) const;
78 
79     // Returns true if extensions usage is allowed in current process.
areExtensionsAllowed()80     bool areExtensionsAllowed() const { return mExtensionsAllowed; }
81 
82     // This method is intended for use only by internal unit tests.
83     //
84     // Registers an extension.
85     //
86     // Returns true if the registration was successful.
forTest_registerExtension(const Extension & extension)87     bool forTest_registerExtension(const Extension& extension) {
88         return registerExtension(extension, "INTERNAL TEST");
89     }
90 
91     // This method is intended for use only by internal unit tests.
92     //
93     // Resets the internal state.
94     //
95     // After calling forTest_registerExtension() any number of times, call
96     // forTest_reset() to return to the state as if forTest_registerExtension()
97     // had never been called. Note that forTest_reset() resets all internal
98     // state (including assigned prefixes) and re-discovers extensions from
99     // available devices.
forTest_reset()100     void forTest_reset() { *this = TypeManager(); }
101 
102     // Collection of app-related arguments for the isExtensionsUseAllowed method.
103     struct AppPackageInfo {
104         // Path of the binary (/proc/$PID/exe)
105         std::string binaryPath;
106         // Package name of the Android app (empty string if not Android app).
107         std::string appPackageName;
108         // Is the app a system app? (false if not an Android app)
109         bool appIsSystemApp;
110         // Is the app preinstalled on vendor image? (false if not an Android app)
111         bool appIsOnVendorImage;
112         // Is the app preinstalled on product image? (false if not an Android app)
113         bool appIsOnProductImage;
114     };
115 
116     // Check if NNAPI Vendor extensions are usable in the process with the given app
117     // and supplemental infomation.
118     //
119     // useOnProductImageEnabled - whether apps/binaries preinstalled on /product partition
120     // can be enabled for extensions use.
121     // allowlist - list of apps/binaries which are allowed to use extensions.
122     static bool isExtensionsUseAllowed(const AppPackageInfo& appPackageInfo,
123                                        bool useOnProductImageEnabled,
124                                        const std::vector<std::string>& allowlist);
125 
126    private:
127     TypeManager();
128     void findAvailableExtensions();
129     bool registerExtension(Extension extension, const std::string& deviceName);
130 
131     // Returns the numeric "prefix" value corresponding to an extension.
132     //
133     // Returns false when assigning a new prefix would overflow uint16_t.
134     bool getExtensionPrefix(const std::string& extensionName, uint16_t* prefix);
135 
136     const DeviceManager* mDeviceManager = DeviceManager::get();
137 
138     // Contains all registered extensions.
139     std::map<std::string, Extension> mExtensionNameToExtension;
140 
141     // Contains the name of the first discovered device that supports an
142     // extension. Used for error reporting.
143     std::map<std::string, std::string> mExtensionNameToFirstDevice;
144 
145     // When multiple devices report conflicting information about an extension,
146     // the extension is disabled.
147     std::set<std::string> mDisabledExtensions;
148 
149     // The fields below are used to support efficient extension name to
150     // prefix mapping. New prefixes are created by getExtensionPrefix.
151     std::map<std::string, uint16_t> mExtensionNameToPrefix;
152     // Entries of mPrefixToExtension point into mExtensionNameToExtension.
153     // prefix=0 corresponds to no extension and should never be looked up.
154     std::vector<Extension*> mPrefixToExtension = {nullptr};
155 
156     // True if Extensions can be used in current process.
157     bool mExtensionsAllowed = false;
158 };
159 
160 }  // namespace nn
161 }  // namespace android
162 
163 #endif  // ANDROID_ML_NN_RUNTIME_EXTENSION_MANAGER_H
164