1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.util;
18 
19 import android.util.Log;
20 
21 import android.app.Activity;
22 import android.os.Bundle;
23 import com.android.nn.benchmark.core.NNTestBase;
24 import com.android.nn.benchmark.core.TestModels.TestModelEntry;
25 import com.android.nn.benchmark.core.TestModels;
26 import java.io.IOException;
27 import java.io.File;
28 
29 
30 /** Helper activity for dumping state of interference intermediate tensors.
31  *
32  * Example usage:
33  * adb shell am start -n  com.android.nn.benchmark.app/com.android.nn.benchmark.\
34  * util.DumpIntermediateTensors --es modelName mobilenet_v1_1.0_224_quant_topk_aosp,tts_float\
35  * inputAssetIndex 0
36  *
37  * Assets will be then dumped into /data/data/com.android.nn.benchmark.app/files/intermediate
38  * To fetch:
39  * adb pull /data/data/com.android.nn.benchmark.app/files/intermediate
40  *
41  */
42 public class DumpIntermediateTensors extends Activity {
43     protected static final String TAG = "VDEBUG";
44     public static final String EXTRA_MODEL_NAME = "modelName";
45     public static final String EXTRA_INPUT_ASSET_INDEX= "inputAssetIndex";
46     public static final String EXTRA_INPUT_ASSET_SIZE= "inputAssetSize";
47     public static final String DUMP_DIR = "intermediate";
48     public static final String CPU_DIR = "cpu";
49     public static final String NNAPI_DIR = "nnapi";
50     // TODO(veralin): Update to use other models in vendor as well.
51     // Due to recent change in NNScoringTest, the model names are moved to here.
52     private static final String[] MODEL_NAMES = new String[]{
53         "tts_float",
54         "asr_float",
55         "mobilenet_v1_1.0_224_quant_topk_aosp",
56         "mobilenet_v1_1.0_224_topk_aosp",
57         "mobilenet_v1_0.75_192_quant_topk_aosp",
58         "mobilenet_v1_0.75_192_topk_aosp",
59         "mobilenet_v1_0.5_160_quant_topk_aosp",
60         "mobilenet_v1_0.5_160_topk_aosp",
61         "mobilenet_v1_0.25_128_quant_topk_aosp",
62         "mobilenet_v1_0.25_128_topk_aosp",
63         "mobilenet_v2_0.35_128_topk_aosp",
64         "mobilenet_v2_0.5_160_topk_aosp",
65         "mobilenet_v2_0.75_192_topk_aosp",
66         "mobilenet_v2_1.0_224_topk_aosp",
67         "mobilenet_v2_1.0_224_quant_topk_aosp",
68     };
69 
70     @Override
onCreate(Bundle savedInstanceState)71     protected void onCreate(Bundle savedInstanceState) {
72         super.onCreate(savedInstanceState);
73         Bundle extras = getIntent().getExtras();
74 
75         String userModelName = extras.getString(EXTRA_MODEL_NAME);
76         int inputAssetIndex = extras.getInt(EXTRA_INPUT_ASSET_INDEX, 0);
77         int inputAssetSize = extras.getInt(EXTRA_INPUT_ASSET_SIZE, 1);
78 
79         // Default to run all models in NNScoringTest
80         String[] modelNames = userModelName == null? MODEL_NAMES: userModelName.split(",");
81 
82         try {
83             File dumpDir = new File(getFilesDir(), DUMP_DIR);
84             safeMkdir(dumpDir);
85 
86             for (String modelName : modelNames) {
87                 File modelDir = new File(getFilesDir() + "/" + DUMP_DIR, modelName);
88                 safeMkdir(modelDir);
89                 // Run in CPU and NNAPI mode
90                 for (final boolean useNNAPI : new boolean[] {false, true}) {
91                     String useNNAPIDir = useNNAPI? NNAPI_DIR: CPU_DIR;
92                     Log.i(TAG, "Running " + modelName + " in " + useNNAPIDir);
93                     TestModelEntry modelEntry = TestModels.getModelByName(modelName);
94                     NNTestBase testBase = modelEntry.createNNTestBase(
95                         useNNAPI, true/*enableIntermediateTensorsDump*/);
96                     testBase.setupModel(this);
97                     File outputDir = new File(getFilesDir() + "/" + DUMP_DIR +
98                         "/" + modelName, useNNAPIDir);
99                     safeMkdir(outputDir);
100                     testBase.dumpAllLayers(outputDir, inputAssetIndex, inputAssetSize);
101                 }
102             }
103 
104         } catch (Exception e) {
105             Log.e(TAG, "Failed to dump tensors", e);
106             throw new IllegalStateException("Failed to dump tensors", e);
107         }
108         finish();
109     }
110 
deleteRecursive(File fileOrDirectory)111     private void deleteRecursive(File fileOrDirectory) {
112         if (fileOrDirectory.isDirectory()) {
113             for (File child : fileOrDirectory.listFiles()) {
114                 deleteRecursive(child);
115             }
116         }
117         fileOrDirectory.delete();
118     }
119 
safeMkdir(File fileOrDirectory)120     private void safeMkdir(File fileOrDirectory) {
121         deleteRecursive(fileOrDirectory);
122         fileOrDirectory.mkdir();
123     }
124 }
125