1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.core;
18 
19 import java.io.BufferedReader;
20 import java.io.File;
21 import java.io.IOException;
22 
23 import android.content.res.AssetManager;
24 import android.util.Log;
25 
26 import com.android.nn.benchmark.util.IOUtils;
27 
28 import java.io.InputStream;
29 import java.io.InputStreamReader;
30 import java.nio.ByteBuffer;
31 import java.util.ArrayList;
32 import java.util.Arrays;
33 import java.util.Collections;
34 import java.util.Comparator;
35 import java.util.HashMap;
36 import java.util.List;
37 
38 /**
39  * Input and expected output sequence pair for inference benchmark.
40  *
41  * Note that it's quite likely this class will need extension with new datasets,
42  * it now supports imagenet-style files and labels only.
43  */
44 public class InferenceInOutSequence {
45     /** Sequence of input/output pairs */
46     private List<InferenceInOut> mInputOutputs;
47     private boolean mHasGoldenOutput;
48     final public int mDatasize;
49 
InferenceInOutSequence(int sequenceLength, boolean hasGoldenOutput, int datasize)50     public InferenceInOutSequence(int sequenceLength, boolean hasGoldenOutput, int datasize) {
51         mInputOutputs = new ArrayList<>(sequenceLength);
52         mHasGoldenOutput = hasGoldenOutput;
53         mDatasize = datasize;
54     }
55 
size()56     public int size() {
57         return mInputOutputs.size();
58     }
59 
get(int i)60     public InferenceInOut get(int i) {
61         return mInputOutputs.get(i);
62     }
63 
hasGoldenOutput()64     public boolean hasGoldenOutput() {
65         return mHasGoldenOutput;
66     }
67 
68     /** Helper class, generates {@link InferenceInOut} from a pair of android asset files */
69     public static class FromAssets {
70         private String mInputAssetName;
71         private String[] mOutputAssetsNames;
72         private int mDataBytesSize;
73         private int mInputSizeBytes;
74 
FromAssets(String inputAssetName, String[] outputAssetsNames, int dataBytesSize, int inputSizeBytes)75         public FromAssets(String inputAssetName, String[] outputAssetsNames, int dataBytesSize,
76                 int inputSizeBytes) {
77             this.mInputAssetName = inputAssetName;
78             this.mOutputAssetsNames = outputAssetsNames;
79             this.mDataBytesSize = dataBytesSize;
80             this.mInputSizeBytes = inputSizeBytes;
81         }
82 
readAssets(AssetManager assetManager)83         public InferenceInOutSequence readAssets(AssetManager assetManager) throws IOException {
84             byte[] inputs = IOUtils.readAsset(assetManager, mInputAssetName, mDataBytesSize);
85             byte[][] outputs = new byte[mOutputAssetsNames.length][];
86             int sequenceLength = inputs.length / mInputSizeBytes;
87 
88             for (int i = 0; i < mOutputAssetsNames.length; ++i) {
89                 outputs[i] = IOUtils.readAsset(assetManager, mOutputAssetsNames[i], mDataBytesSize);
90                 if (outputs[i].length % sequenceLength != 0) {
91                     throw new IllegalArgumentException(
92                             "Output data " + mOutputAssetsNames[i] + " size (in bytes): " +
93                                     outputs[i].length + " is not a multiple of sequence length: " +
94                                     sequenceLength);
95                 }
96             }
97             if (inputs.length % mInputSizeBytes != 0) {
98                 throw new IllegalArgumentException("Input data size (in bytes): " + inputs.length +
99                         " is not a multiple of input size (in bytes): " + mInputSizeBytes);
100             }
101             InferenceInOutSequence sequence = new InferenceInOutSequence(
102                     sequenceLength, true, mDataBytesSize);
103 
104             for (int i = 0; i < sequenceLength; ++i) {
105                 byte[][] outz = new byte[mOutputAssetsNames.length][];
106                 for (int j = 0; j < mOutputAssetsNames.length; ++j) {
107                     int outputSizeBytes = outputs[j].length / sequenceLength;
108                     outz[j] = Arrays.copyOfRange(outputs[j], outputSizeBytes * i,
109                             outputSizeBytes * (i + 1));
110                 }
111 
112                 sequence.mInputOutputs.add(new InferenceInOut(
113                         Arrays.copyOfRange(inputs, mInputSizeBytes * i, mInputSizeBytes * (i + 1)),
114                         outz,
115                         -1));
116             }
117             return sequence;
118         }
119     }
120 
121     /**
122      * Helper class, generates {@link InferenceInOut}[] from a directory with image files,
123      * (optional) set of labels and an image preprocessor.
124      *
125      * The images and ground truth should look like imagenet: the images in the directory
126      * must be name <prefix>-<number>.<extension>, where the number is used to find the
127      * corresponding line in the ground truth labels.
128      */
129     public static class FromDataset {
130         private String mInputPath;
131         private String mLabelAssetName;
132         private String mGroundTruthAssetName;
133         private String mPreprocessorName;
134         private int mDatasize;
135         private float mQuantScale;
136         private float mQuantZeroPoint;
137         private int mImageDimension;
138 
FromDataset(String inputPath, String labelAssetName, String groundTruthAssetName, String preprocessorName, int datasize, float quantScale, float quantZeroPoint, int imageDimension)139         public FromDataset(String inputPath, String labelAssetName, String groundTruthAssetName,
140                 String preprocessorName, int datasize,
141                 float quantScale, float quantZeroPoint,
142                 int imageDimension) {
143             mInputPath = inputPath;
144             if (mInputPath.endsWith("/")) {
145                 mInputPath = mInputPath.substring(0, mInputPath.length() - 1);
146             }
147             mLabelAssetName = labelAssetName;
148             mGroundTruthAssetName = groundTruthAssetName;
149             mPreprocessorName = preprocessorName;
150             mDatasize = datasize;
151             mQuantScale = quantScale;
152             mQuantZeroPoint = quantZeroPoint;
153             mImageDimension = imageDimension;
154         }
155 
isImageFile(String fileName)156         private boolean isImageFile(String fileName) {
157             String lower = fileName.toLowerCase();
158             return (lower.endsWith(".jpeg") || lower.endsWith(".jpg"));
159         }
160 
createImageProcessor()161         private ImageProcessorInterface createImageProcessor() {
162             try {
163                 Class<?> clazz = Class.forName(
164                         "com.android.nn.benchmark.imageprocessors." + mPreprocessorName);
165                 return (ImageProcessorInterface) clazz.getConstructor().newInstance();
166             } catch (Exception e) {
167                 throw new IllegalArgumentException(
168                         "Can not create image processors named '" + mPreprocessorName + "'",
169                         e);
170             }
171         }
172 
getIndexFromFilename(String filename)173         private static Integer getIndexFromFilename(String filename) {
174             String index = filename.split("-")[1].split("\\.")[0];
175             return Integer.valueOf(index, 10);
176         }
177 
readDataset( final AssetManager assetManager, final File cacheDir)178         public ArrayList<InferenceInOutSequence> readDataset(
179                 final AssetManager assetManager, final File cacheDir) throws IOException {
180             String[] allFileNames = assetManager.list(mInputPath);
181             ArrayList<String> imageFileNames = new ArrayList<String>();
182             for (String fileName : allFileNames) {
183                 if (isImageFile(fileName)) {
184                     imageFileNames.add(fileName);
185                 }
186             }
187             Collections.sort(imageFileNames, new Comparator<String>() {
188                 @Override
189                 public int compare(String o1, String o2) {
190                     Integer index1 = getIndexFromFilename(o1);
191                     Integer index2 = getIndexFromFilename(o2);
192                     return index1.compareTo(index2);
193                 }
194             });
195 
196             Integer[] expectedClasses = null;
197             HashMap<String, Integer> labelMap = null;
198             if (mLabelAssetName != null) {
199                 labelMap = new HashMap<String, Integer>();
200                 InputStream labelStream = assetManager.open(mLabelAssetName);
201                 BufferedReader labelReader = new BufferedReader(
202                         new InputStreamReader(labelStream, "UTF-8"));
203                 String line;
204                 int index = 0;
205                 while ((line = labelReader.readLine()) != null) {
206                     labelMap.put(line, new Integer(index));
207                     index++;
208                 }
209             }
210             if (mGroundTruthAssetName != null) {
211                 expectedClasses = new Integer[imageFileNames.size()];
212                 InputStream truthStream = assetManager.open(mGroundTruthAssetName);
213                 BufferedReader truthReader = new BufferedReader(
214                         new InputStreamReader(truthStream, "UTF-8"));
215                 String line;
216                 int index = 0;
217                 while ((line = truthReader.readLine()) != null) {
218                     if (labelMap != null) {
219                         expectedClasses[index] = labelMap.get(line);
220                     } else {
221                         expectedClasses[index] = Integer.parseInt(line, 10);
222                     }
223                     index++;
224                 }
225             }
226 
227             ArrayList<InferenceInOutSequence> ret = new ArrayList<InferenceInOutSequence>();
228             final ImageProcessorInterface imageProcessor = createImageProcessor();
229 
230             for (int i = 0; i < imageFileNames.size(); i++) {
231                 final String fileName = mInputPath + '/' + imageFileNames.get(i);
232                 int expectedClass = -1;
233                 if (expectedClasses != null) {
234                     expectedClass = expectedClasses[i];
235                 }
236                 InferenceInOut.InputCreatorInterface creator =
237                         new InferenceInOut.InputCreatorInterface() {
238                             @Override
239                             public void createInput(ByteBuffer buffer) {
240                                 try {
241                                     imageProcessor.preprocess(mDatasize,
242                                             mQuantScale, mQuantZeroPoint, mImageDimension,
243                                             assetManager, fileName, cacheDir, buffer);
244                                 } catch (Throwable t) {
245                                     throw new Error("Failed to create image input", t);
246                                 }
247                             }
248                         };
249                 InferenceInOutSequence sequence = new InferenceInOutSequence(
250                         1, false, mDatasize);
251                 sequence.mInputOutputs.add(new InferenceInOut(creator, null,
252                         expectedClass));
253                 ret.add(sequence);
254             }
255             return ret;
256         }
257     }
258 }
259