1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.core;
18 
import android.content.res.AssetManager;
import android.util.Log;

import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
30 
31 /** Helper class to register test model definitions from assets data */
32 public class TestModelsListLoader {
33     private static final String TAG = "NN_BENCHMARK";
34 
35     /**
36      * Parse list of models in form of json data.
37      *
38      * Example input:
39      * { "models" : [
40      * {"name" : "modelName",
41      * "testName" : "testName",
42      * "baselineSec" : 0.03,
43      * "evaluator": "TopK",
44      * "inputSize" : [1,2,3,4],
45      * "dataSize" : 4,
46      * "inputOutputs" : [ {"input": "input1", "output": "output2"} ]
47      * }
48      * ]}
49      */
parseJSONModelsList(String jsonStringInput)50     static public void parseJSONModelsList(String jsonStringInput) throws JSONException {
51         JSONObject jsonRootObject = new JSONObject(jsonStringInput);
52         JSONArray jsonModelsArray = jsonRootObject.getJSONArray("models");
53 
54         for (int i = 0; i < jsonModelsArray.length(); i++) {
55             JSONObject jsonTestModelEntry = jsonModelsArray.getJSONObject(i);
56 
57             String name = jsonTestModelEntry.getString("name");
58             String testName = name;
59             if (jsonTestModelEntry.has("testName")) {
60                 testName = jsonTestModelEntry.getString("testName");
61             }
62             String modelFile = name;
63             if (jsonTestModelEntry.has("modelFile")) {
64                 modelFile = jsonTestModelEntry.getString("modelFile");
65             }
66             double baseline = jsonTestModelEntry.getDouble("baselineSec");
67             int minSdkVersion = 0;
68             if (jsonTestModelEntry.has("minSdkVersion")) {
69                 minSdkVersion = jsonTestModelEntry.getInt("minSdkVersion");
70             }
71             EvaluatorConfig evaluator = null;
72             if (jsonTestModelEntry.has("evaluator")) {
73                 JSONObject evaluatorJson = jsonTestModelEntry.getJSONObject("evaluator");
74                 evaluator = new EvaluatorConfig(evaluatorJson.getString("className"),
75                         evaluatorJson.has("outputMeanStdDev")
76                                 ? evaluatorJson.getString("outputMeanStdDev")
77                                 : null,
78                         evaluatorJson.has("expectedTop1")
79                                 ? evaluatorJson.getDouble("expectedTop1")
80                                 : null);
81             }
82 
83             int dataSize = jsonTestModelEntry.getInt("dataSize");
84             JSONArray jsonInputSize = jsonTestModelEntry.getJSONArray("inputSize");
85             int[] inputSize = new int[jsonInputSize.length()];
86             int inputSizeBytes = dataSize;
87             for (int k = 0; k < jsonInputSize.length(); ++k) {
88                 inputSize[k] = jsonInputSize.getInt(k);
89                 inputSizeBytes *= inputSize[k];
90             }
91 
92             InferenceInOutSequence.FromAssets[] inputOutputs = null;
93             if (jsonTestModelEntry.has("inputOutputs")) {
94                 JSONArray jsonInputOutputs = jsonTestModelEntry.getJSONArray("inputOutputs");
95                 inputOutputs =
96                         new InferenceInOutSequence.FromAssets[jsonInputOutputs.length()];
97 
98                 for (int j = 0; j < jsonInputOutputs.length(); j++) {
99                     JSONObject jsonInputOutput = jsonInputOutputs.getJSONObject(j);
100                     String input = jsonInputOutput.getString("input");
101                     String[] outputs = null;
102                     String output = jsonInputOutput.optString("output", null);
103                     if (output != null) {
104                         outputs = new String[]{output};
105                     } else {
106                         JSONArray outputArray = jsonInputOutput.getJSONArray("outputs");
107                         if (outputArray != null) {
108                             outputs = new String[outputArray.length()];
109                             for (int k = 0; k < outputArray.length(); ++k) {
110                                 outputs[k] = outputArray.getString(k);
111                             }
112                         }
113                     }
114 
115                     inputOutputs[j] = new InferenceInOutSequence.FromAssets(input, outputs,
116                             dataSize,
117                             inputSizeBytes);
118                 }
119             }
120             InferenceInOutSequence.FromDataset[] datasets = null;
121             if (jsonTestModelEntry.has("dataset")) {
122                 JSONObject jsonDataset = jsonTestModelEntry.getJSONObject("dataset");
123                 String inputPath = jsonDataset.getString("inputPath");
124                 String groundTruth = jsonDataset.getString("groundTruth");
125                 String labels = jsonDataset.getString("labels");
126                 String preprocessor = jsonDataset.getString("preprocessor");
127                 if (inputSize.length != 4 || inputSize[0] != 1 || inputSize[1] != inputSize[2] ||
128                         inputSize[3] != 3) {
129                     throw new IllegalArgumentException("Datasets only support square images," +
130                             "input size [1, D, D, 3], given " + inputSize[0] +
131                             ", " + inputSize[1] + ", " + inputSize[2] + ", " + inputSize[3]);
132                 }
133                 float quantScale = 0.f;
134                 float quantZeroPoint = 0.f;
135                 if (dataSize == 1) {
136                     if (!jsonTestModelEntry.has("inputScale") ||
137                             !jsonTestModelEntry.has("inputZeroPoint")) {
138                         throw new IllegalArgumentException("Quantized test model must include " +
139                                 "inputScale and inputZeroPoint for reading a dataset");
140                     }
141                     quantScale = (float) jsonTestModelEntry.getDouble("inputScale");
142                     quantZeroPoint = (float) jsonTestModelEntry.getDouble("inputZeroPoint");
143                 }
144                 datasets = new InferenceInOutSequence.FromDataset[]{
145                         new InferenceInOutSequence.FromDataset(inputPath, labels, groundTruth,
146                                 preprocessor, dataSize, quantScale, quantZeroPoint, inputSize[1])
147                 };
148             }
149 
150             TestModels.registerModel(
151                 new TestModels.TestModelEntry(name, (float) baseline, inputSize, inputOutputs,
152                     datasets, testName, modelFile, evaluator, minSdkVersion, dataSize));
153         }
154     }
155 
readAssetsFileAsString(InputStream inputStream)156     static String readAssetsFileAsString(InputStream inputStream) throws IOException {
157         Reader reader = new InputStreamReader(inputStream);
158         StringBuilder sb = new StringBuilder();
159         char buffer[] = new char[16384];
160         int len;
161         while ((len = reader.read(buffer)) > 0) {
162             sb.append(buffer, 0, len);
163         }
164         reader.close();
165         return sb.toString();
166     }
167 
168     /** Parse all ".json" files in root assets directory */
169     private static final String MODELS_LIST_ROOT = "models_list";
170 
parseFromAssets(AssetManager assetManager)171     static public void parseFromAssets(AssetManager assetManager) throws IOException {
172         for (String file : assetManager.list(MODELS_LIST_ROOT)) {
173             if (!file.endsWith(".json")) {
174                 continue;
175             }
176             try {
177                 parseJSONModelsList(readAssetsFileAsString(
178                         assetManager.open(MODELS_LIST_ROOT + "/" + file)));
179             } catch (JSONException e) {
180                 Log.e(TAG, "error reading json model list", e);
181                 throw new IOException("JSON error in " + file, e);
182             } catch (Exception e) {
183                 Log.e(TAG, "error parsing json model list", e);
184                 // Wrap exception to add a filename to it
185                 throw new IOException("Error while parsing " + file, e);
186             }
187 
188         }
189     }
190 }
191