1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "TypeManager"
18 
19 #include "TypeManager.h"
20 
21 #include "Utils.h"
22 
23 #include <android-base/file.h>
24 #include <android-base/properties.h>
25 #include <android/content/pm/IPackageManagerNative.h>
26 #include <binder/IServiceManager.h>
27 #include <procpartition/procpartition.h>
28 #include <algorithm>
29 #include <string_view>
30 
31 namespace android {
32 namespace nn {
33 
34 // Replacement function for std::string_view::starts_with()
35 // which shall be available in C++20.
36 #if __cplusplus >= 202000L
37 #error "When upgrading to C++20, remove this error and file a bug to remove this workaround."
38 #endif
// Reports whether `sv` begins with `prefix`. An empty prefix matches every
// string; a prefix longer than `sv` never matches.
inline bool StartsWith(std::string_view sv, std::string_view prefix) {
    const auto prefixLength = prefix.size();
    if (sv.size() < prefixLength) {
        return false;
    }
    return sv.compare(0u, prefixLength, prefix) == 0;
}
42 
43 namespace {
44 
45 const uint8_t kLowBitsType = static_cast<uint8_t>(Model::ExtensionTypeEncoding::LOW_BITS_TYPE);
46 const uint32_t kMaxPrefix =
47         (1 << static_cast<uint8_t>(Model::ExtensionTypeEncoding::HIGH_BITS_PREFIX)) - 1;
48 
// Checks if the two structures contain the same information: the same
// extension name and the same list of operand types.
//
// The operand type lists are compared element by element with operator==, so
// the comparison is insensitive to the order reported by a driver only because
// TypeManager::registerExtension sorts operandTypes before calling this.
//
// Returns true when both checks pass. A failing NN_RET_CHECK makes the function
// return false (presumably after logging the failed condition -- confirm
// against the macro definition in Utils.h).
bool equal(const Extension& a, const Extension& b) {
    NN_RET_CHECK_EQ(a.name, b.name);
    // Relies on the fact that TypeManager sorts operandTypes.
    NN_RET_CHECK(a.operandTypes == b.operandTypes);
    return true;
}
57 
58 // Property for disabling NNAPI vendor extensions on product image (used on GSI /product image,
59 // which can't use NNAPI vendor extensions).
60 const char kVExtProductDeny[] = "ro.nnapi.extensions.deny_on_product";
isNNAPIVendorExtensionsUseAllowedInProductImage()61 bool isNNAPIVendorExtensionsUseAllowedInProductImage() {
62     const std::string vExtProductDeny = android::base::GetProperty(kVExtProductDeny, "");
63     return vExtProductDeny.empty();
64 }
65 
66 // The file containing the list of Android apps and binaries allowed to use vendor extensions.
67 // Each line of the file contains new entry. If entry is prefixed by
68 // '/' slash, then it's a native binary path (e.g. '/data/foo'). If not, it's a name
69 // of Android app package (e.g. 'com.foo.bar').
70 const char kAppAllowlistPath[] = "/vendor/etc/nnapi_extensions_app_allowlist";
71 const char kCtsAllowlist[] = "/data/local/tmp/CTSNNAPITestCases";
getVendorExtensionAllowlistedApps()72 std::vector<std::string> getVendorExtensionAllowlistedApps() {
73     std::string data;
74     // Allowlist CTS by default.
75     std::vector<std::string> allowlist = {kCtsAllowlist};
76 
77     if (!android::base::ReadFileToString(kAppAllowlistPath, &data)) {
78         // Return default allowlist (no app can use extensions).
79         LOG(INFO) << "Failed to read " << kAppAllowlistPath
80                   << " ; No app allowlisted for vendor extensions use.";
81         return allowlist;
82     }
83 
84     std::istringstream streamData(data);
85     std::string line;
86     while (std::getline(streamData, line)) {
87         // Do some basic sanity check on entry, it's either
88         // fs path or package name.
89         if (StartsWith(line, "/") || line.find('.') != std::string::npos) {
90             allowlist.push_back(line);
91         } else {
92             LOG(ERROR) << kAppAllowlistPath << " - Invalid entry: " << line;
93         }
94     }
95     return allowlist;
96 }
97 
98 // Query PackageManagerNative service about Android app properties.
99 // On success, it will populate appPackageInfo->app* fields.
fetchAppPackageLocationInfo(uid_t uid,TypeManager::AppPackageInfo * appPackageInfo)100 bool fetchAppPackageLocationInfo(uid_t uid, TypeManager::AppPackageInfo* appPackageInfo) {
101     sp<::android::IServiceManager> sm(::android::defaultServiceManager());
102     sp<::android::IBinder> binder(sm->getService(String16("package_native")));
103     if (binder == nullptr) {
104         LOG(ERROR) << "getService package_native failed";
105         return false;
106     }
107 
108     sp<content::pm::IPackageManagerNative> packageMgr =
109             interface_cast<content::pm::IPackageManagerNative>(binder);
110     std::vector<int> uids{static_cast<int>(uid)};
111     std::vector<std::string> names;
112     binder::Status status = packageMgr->getNamesForUids(uids, &names);
113     if (!status.isOk()) {
114         LOG(ERROR) << "package_native::getNamesForUids failed: "
115                    << status.exceptionMessage().c_str();
116         return false;
117     }
118     const std::string& packageName = names[0];
119 
120     appPackageInfo->appPackageName = packageName;
121     int flags = 0;
122     status = packageMgr->getLocationFlags(packageName, &flags);
123     if (!status.isOk()) {
124         LOG(ERROR) << "package_native::getLocationFlags failed: "
125                    << status.exceptionMessage().c_str();
126         return false;
127     }
128     // isSystemApp()
129     appPackageInfo->appIsSystemApp =
130             ((flags & content::pm::IPackageManagerNative::LOCATION_SYSTEM) != 0);
131     // isVendor()
132     appPackageInfo->appIsOnVendorImage =
133             ((flags & content::pm::IPackageManagerNative::LOCATION_VENDOR) != 0);
134     // isProduct()
135     appPackageInfo->appIsOnProductImage =
136             ((flags & content::pm::IPackageManagerNative::LOCATION_PRODUCT) != 0);
137     return true;
138 }
139 
140 // Check if this process is allowed to use NNAPI Vendor extensions.
isNNAPIVendorExtensionsUseAllowed(const std::vector<std::string> & allowlist)141 bool isNNAPIVendorExtensionsUseAllowed(const std::vector<std::string>& allowlist) {
142     TypeManager::AppPackageInfo appPackageInfo = {
143             .binaryPath = ::android::procpartition::getExe(getpid()),
144             .appPackageName = "",
145             .appIsSystemApp = false,
146             .appIsOnVendorImage = false,
147             .appIsOnProductImage = false};
148 
149     if (appPackageInfo.binaryPath == "/system/bin/app_process64" ||
150         appPackageInfo.binaryPath == "/system/bin/app_process32") {
151         if (!fetchAppPackageLocationInfo(getuid(), &appPackageInfo)) {
152             LOG(ERROR) << "Failed to get app information from package_manager_native";
153             return false;
154         }
155     }
156     return TypeManager::isExtensionsUseAllowed(
157             appPackageInfo, isNNAPIVendorExtensionsUseAllowedInProductImage(), allowlist);
158 }
159 
160 }  // namespace
161 
TypeManager()162 TypeManager::TypeManager() {
163     VLOG(MANAGER) << "TypeManager::TypeManager";
164     mExtensionsAllowed = isNNAPIVendorExtensionsUseAllowed(getVendorExtensionAllowlistedApps());
165     VLOG(MANAGER) << "NNAPI Vendor extensions enabled: " << mExtensionsAllowed;
166     findAvailableExtensions();
167 }
168 
isExtensionsUseAllowed(const AppPackageInfo & appPackageInfo,bool useOnProductImageEnabled,const std::vector<std::string> & allowlist)169 bool TypeManager::isExtensionsUseAllowed(const AppPackageInfo& appPackageInfo,
170                                          bool useOnProductImageEnabled,
171                                          const std::vector<std::string>& allowlist) {
172     // Only selected partitions and user-installed apps (/data)
173     // are allowed to use extensions.
174     if (StartsWith(appPackageInfo.binaryPath, "/vendor/") ||
175         StartsWith(appPackageInfo.binaryPath, "/odm/") ||
176         StartsWith(appPackageInfo.binaryPath, "/data/") ||
177         (StartsWith(appPackageInfo.binaryPath, "/product/") && useOnProductImageEnabled)) {
178 #ifdef NN_DEBUGGABLE
179         // Only on userdebug and eng builds.
180         // When running tests with mma and adb push.
181         if (StartsWith(appPackageInfo.binaryPath, "/data/nativetest") ||
182             // When running tests with Atest.
183             StartsWith(appPackageInfo.binaryPath, "/data/local/tmp/NeuralNetworksTest_")) {
184             return true;
185         }
186 #endif  // NN_DEBUGGABLE
187 
188         return std::find(allowlist.begin(), allowlist.end(), appPackageInfo.binaryPath) !=
189                allowlist.end();
190     } else if (appPackageInfo.binaryPath == "/system/bin/app_process64" ||
191                appPackageInfo.binaryPath == "/system/bin/app_process32") {
192         // App is not system app OR vendor app OR (product app AND product enabled)
193         // AND app is on allowlist.
194         return (!appPackageInfo.appIsSystemApp || appPackageInfo.appIsOnVendorImage ||
195                 (appPackageInfo.appIsOnProductImage && useOnProductImageEnabled)) &&
196                std::find(allowlist.begin(), allowlist.end(), appPackageInfo.appPackageName) !=
197                        allowlist.end();
198     }
199     return false;
200 }
201 
findAvailableExtensions()202 void TypeManager::findAvailableExtensions() {
203     for (const std::shared_ptr<Device>& device : mDeviceManager->getDrivers()) {
204         for (const Extension extension : device->getSupportedExtensions()) {
205             registerExtension(extension, device->getName());
206         }
207     }
208 }
209 
registerExtension(Extension extension,const std::string & deviceName)210 bool TypeManager::registerExtension(Extension extension, const std::string& deviceName) {
211     if (mDisabledExtensions.find(extension.name) != mDisabledExtensions.end()) {
212         LOG(ERROR) << "Extension " << extension.name << " is disabled";
213         return false;
214     }
215 
216     std::sort(extension.operandTypes.begin(), extension.operandTypes.end(),
217               [](const Extension::OperandTypeInformation& a,
218                  const Extension::OperandTypeInformation& b) {
219                   return static_cast<uint16_t>(a.type) < static_cast<uint16_t>(b.type);
220               });
221 
222     std::map<std::string, Extension>::iterator it;
223     bool isNew;
224     std::tie(it, isNew) = mExtensionNameToExtension.emplace(extension.name, extension);
225     if (isNew) {
226         VLOG(MANAGER) << "Registered extension " << extension.name;
227         mExtensionNameToFirstDevice.emplace(extension.name, deviceName);
228     } else if (!equal(extension, it->second)) {
229         LOG(ERROR) << "Devices " << mExtensionNameToFirstDevice[extension.name] << " and "
230                    << deviceName << " provide inconsistent information for extension "
231                    << extension.name << ", which is therefore disabled";
232         mExtensionNameToExtension.erase(it);
233         mDisabledExtensions.insert(extension.name);
234         return false;
235     }
236     return true;
237 }
238 
getExtensionPrefix(const std::string & extensionName,uint16_t * prefix)239 bool TypeManager::getExtensionPrefix(const std::string& extensionName, uint16_t* prefix) {
240     auto it = mExtensionNameToPrefix.find(extensionName);
241     if (it != mExtensionNameToPrefix.end()) {
242         *prefix = it->second;
243     } else {
244         NN_RET_CHECK_LE(mPrefixToExtension.size(), kMaxPrefix) << "Too many extensions in use";
245         *prefix = mPrefixToExtension.size();
246         mExtensionNameToPrefix[extensionName] = *prefix;
247         mPrefixToExtension.push_back(&mExtensionNameToExtension[extensionName]);
248     }
249     return true;
250 }
251 
getExtensionType(const char * extensionName,uint16_t typeWithinExtension,int32_t * type)252 bool TypeManager::getExtensionType(const char* extensionName, uint16_t typeWithinExtension,
253                                    int32_t* type) {
254     uint16_t prefix;
255     NN_RET_CHECK(getExtensionPrefix(extensionName, &prefix));
256     *type = (prefix << kLowBitsType) | typeWithinExtension;
257     return true;
258 }
259 
// Looks up the extension that was assigned `prefix`.
//
// On success stores the mPrefixToExtension entry for `prefix` in `*extension`
// and returns true. Returns false, leaving `*extension` untouched, when
// `prefix` is 0 (which does not denote an extension) or when it is not smaller
// than the number of entries in mPrefixToExtension.
bool TypeManager::getExtensionInfo(uint16_t prefix, const Extension** extension) const {
    NN_RET_CHECK_NE(prefix, 0u) << "prefix=0 does not correspond to an extension";
    NN_RET_CHECK_LT(prefix, mPrefixToExtension.size()) << "Unknown extension prefix";
    *extension = mPrefixToExtension[prefix];
    return true;
}
266 
getExtensionOperandTypeInfo(OperandType type,const Extension::OperandTypeInformation ** info) const267 bool TypeManager::getExtensionOperandTypeInfo(
268         OperandType type, const Extension::OperandTypeInformation** info) const {
269     uint32_t operandType = static_cast<uint32_t>(type);
270     uint16_t prefix = operandType >> kLowBitsType;
271     uint16_t typeWithinExtension = operandType & ((1 << kLowBitsType) - 1);
272     const Extension* extension;
273     NN_RET_CHECK(getExtensionInfo(prefix, &extension))
274             << "Cannot find extension corresponding to prefix " << prefix;
275     auto it = std::lower_bound(
276             extension->operandTypes.begin(), extension->operandTypes.end(), typeWithinExtension,
277             [](const Extension::OperandTypeInformation& info, uint32_t typeSought) {
278                 return static_cast<uint16_t>(info.type) < typeSought;
279             });
280     NN_RET_CHECK(it != extension->operandTypes.end() &&
281                  static_cast<uint16_t>(it->type) == typeWithinExtension)
282             << "Cannot find operand type " << typeWithinExtension << " in extension "
283             << extension->name;
284     *info = &*it;
285     return true;
286 }
287 
isTensorType(OperandType type) const288 bool TypeManager::isTensorType(OperandType type) const {
289     if (!isExtensionOperandType(type)) {
290         return !nonExtensionOperandTypeIsScalar(static_cast<int>(type));
291     }
292     const Extension::OperandTypeInformation* info;
293     CHECK(getExtensionOperandTypeInfo(type, &info));
294     return info->isTensor;
295 }
296 
getSizeOfData(OperandType type,const std::vector<uint32_t> & dimensions) const297 uint32_t TypeManager::getSizeOfData(OperandType type,
298                                     const std::vector<uint32_t>& dimensions) const {
299     if (!isExtensionOperandType(type)) {
300         return nonExtensionOperandSizeOfData(type, dimensions);
301     }
302 
303     const Extension::OperandTypeInformation* info;
304     CHECK(getExtensionOperandTypeInfo(type, &info));
305 
306     if (!info->isTensor) {
307         return info->byteSize;
308     }
309 
310     if (dimensions.empty()) {
311         return 0;
312     }
313 
314     uint32_t size = info->byteSize;
315     for (auto dimension : dimensions) {
316         size *= dimension;
317     }
318     return size;
319 }
320 
321 }  // namespace nn
322 }  // namespace android
323