1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.util; 18 19 import android.util.Log; 20 21 import android.app.Activity; 22 import android.os.Bundle; 23 import com.android.nn.benchmark.core.NNTestBase; 24 import com.android.nn.benchmark.core.TestModels.TestModelEntry; 25 import com.android.nn.benchmark.core.TestModels; 26 import java.io.IOException; 27 import java.io.File; 28 29 30 /** Helper activity for dumping state of interference intermediate tensors. 31 * 32 * Example usage: 33 * adb shell am start -n com.android.nn.benchmark.app/com.android.nn.benchmark.\ 34 * util.DumpIntermediateTensors --es modelName mobilenet_v1_1.0_224_quant_topk_aosp,tts_float\ 35 * inputAssetIndex 0 36 * 37 * Assets will be then dumped into /data/data/com.android.nn.benchmark.app/files/intermediate 38 * To fetch: 39 * adb pull /data/data/com.android.nn.benchmark.app/files/intermediate 40 * 41 */ 42 public class DumpIntermediateTensors extends Activity { 43 protected static final String TAG = "VDEBUG"; 44 public static final String EXTRA_MODEL_NAME = "modelName"; 45 public static final String EXTRA_INPUT_ASSET_INDEX= "inputAssetIndex"; 46 public static final String EXTRA_INPUT_ASSET_SIZE= "inputAssetSize"; 47 public static final String DUMP_DIR = "intermediate"; 48 public static final String CPU_DIR = "cpu"; 49 public static final String NNAPI_DIR = "nnapi"; 50 // TODO(veralin): Update to use other models in vendor as well. 51 // Due to recent change in NNScoringTest, the model names are moved to here. 52 private static final String[] MODEL_NAMES = new String[]{ 53 "tts_float", 54 "asr_float", 55 "mobilenet_v1_1.0_224_quant_topk_aosp", 56 "mobilenet_v1_1.0_224_topk_aosp", 57 "mobilenet_v1_0.75_192_quant_topk_aosp", 58 "mobilenet_v1_0.75_192_topk_aosp", 59 "mobilenet_v1_0.5_160_quant_topk_aosp", 60 "mobilenet_v1_0.5_160_topk_aosp", 61 "mobilenet_v1_0.25_128_quant_topk_aosp", 62 "mobilenet_v1_0.25_128_topk_aosp", 63 "mobilenet_v2_0.35_128_topk_aosp", 64 "mobilenet_v2_0.5_160_topk_aosp", 65 "mobilenet_v2_0.75_192_topk_aosp", 66 "mobilenet_v2_1.0_224_topk_aosp", 67 "mobilenet_v2_1.0_224_quant_topk_aosp", 68 }; 69 70 @Override onCreate(Bundle savedInstanceState)71 protected void onCreate(Bundle savedInstanceState) { 72 super.onCreate(savedInstanceState); 73 Bundle extras = getIntent().getExtras(); 74 75 String userModelName = extras.getString(EXTRA_MODEL_NAME); 76 int inputAssetIndex = extras.getInt(EXTRA_INPUT_ASSET_INDEX, 0); 77 int inputAssetSize = extras.getInt(EXTRA_INPUT_ASSET_SIZE, 1); 78 79 // Default to run all models in NNScoringTest 80 String[] modelNames = userModelName == null? MODEL_NAMES: userModelName.split(","); 81 82 try { 83 File dumpDir = new File(getFilesDir(), DUMP_DIR); 84 safeMkdir(dumpDir); 85 86 for (String modelName : modelNames) { 87 File modelDir = new File(getFilesDir() + "/" + DUMP_DIR, modelName); 88 safeMkdir(modelDir); 89 // Run in CPU and NNAPI mode 90 for (final boolean useNNAPI : new boolean[] {false, true}) { 91 String useNNAPIDir = useNNAPI? NNAPI_DIR: CPU_DIR; 92 Log.i(TAG, "Running " + modelName + " in " + useNNAPIDir); 93 TestModelEntry modelEntry = TestModels.getModelByName(modelName); 94 NNTestBase testBase = modelEntry.createNNTestBase( 95 useNNAPI, true/*enableIntermediateTensorsDump*/); 96 testBase.setupModel(this); 97 File outputDir = new File(getFilesDir() + "/" + DUMP_DIR + 98 "/" + modelName, useNNAPIDir); 99 safeMkdir(outputDir); 100 testBase.dumpAllLayers(outputDir, inputAssetIndex, inputAssetSize); 101 } 102 } 103 104 } catch (Exception e) { 105 Log.e(TAG, "Failed to dump tensors", e); 106 throw new IllegalStateException("Failed to dump tensors", e); 107 } 108 finish(); 109 } 110 deleteRecursive(File fileOrDirectory)111 private void deleteRecursive(File fileOrDirectory) { 112 if (fileOrDirectory.isDirectory()) { 113 for (File child : fileOrDirectory.listFiles()) { 114 deleteRecursive(child); 115 } 116 } 117 fileOrDirectory.delete(); 118 } 119 safeMkdir(File fileOrDirectory)120 private void safeMkdir(File fileOrDirectory) { 121 deleteRecursive(fileOrDirectory); 122 fileOrDirectory.mkdir(); 123 } 124 } 125