1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.core; 18 19 import java.io.BufferedReader; 20 import java.io.File; 21 import java.io.IOException; 22 23 import android.content.res.AssetManager; 24 import android.util.Log; 25 26 import com.android.nn.benchmark.util.IOUtils; 27 28 import java.io.InputStream; 29 import java.io.InputStreamReader; 30 import java.nio.ByteBuffer; 31 import java.util.ArrayList; 32 import java.util.Arrays; 33 import java.util.Collections; 34 import java.util.Comparator; 35 import java.util.HashMap; 36 import java.util.List; 37 38 /** 39 * Input and expected output sequence pair for inference benchmark. 40 * 41 * Note that it's quite likely this class will need extension with new datasets, 42 * it now supports imagenet-style files and labels only. 43 */ 44 public class InferenceInOutSequence { 45 /** Sequence of input/output pairs */ 46 private List<InferenceInOut> mInputOutputs; 47 private boolean mHasGoldenOutput; 48 final public int mDatasize; 49 InferenceInOutSequence(int sequenceLength, boolean hasGoldenOutput, int datasize)50 public InferenceInOutSequence(int sequenceLength, boolean hasGoldenOutput, int datasize) { 51 mInputOutputs = new ArrayList<>(sequenceLength); 52 mHasGoldenOutput = hasGoldenOutput; 53 mDatasize = datasize; 54 } 55 size()56 public int size() { 57 return mInputOutputs.size(); 58 } 59 get(int i)60 public InferenceInOut get(int i) { 61 return mInputOutputs.get(i); 62 } 63 hasGoldenOutput()64 public boolean hasGoldenOutput() { 65 return mHasGoldenOutput; 66 } 67 68 /** Helper class, generates {@link InferenceInOut} from a pair of android asset files */ 69 public static class FromAssets { 70 private String mInputAssetName; 71 private String[] mOutputAssetsNames; 72 private int mDataBytesSize; 73 private int mInputSizeBytes; 74 FromAssets(String inputAssetName, String[] outputAssetsNames, int dataBytesSize, int inputSizeBytes)75 public FromAssets(String inputAssetName, String[] outputAssetsNames, int dataBytesSize, 76 int inputSizeBytes) { 77 this.mInputAssetName = inputAssetName; 78 this.mOutputAssetsNames = outputAssetsNames; 79 this.mDataBytesSize = dataBytesSize; 80 this.mInputSizeBytes = inputSizeBytes; 81 } 82 readAssets(AssetManager assetManager)83 public InferenceInOutSequence readAssets(AssetManager assetManager) throws IOException { 84 byte[] inputs = IOUtils.readAsset(assetManager, mInputAssetName, mDataBytesSize); 85 byte[][] outputs = new byte[mOutputAssetsNames.length][]; 86 int sequenceLength = inputs.length / mInputSizeBytes; 87 88 for (int i = 0; i < mOutputAssetsNames.length; ++i) { 89 outputs[i] = IOUtils.readAsset(assetManager, mOutputAssetsNames[i], mDataBytesSize); 90 if (outputs[i].length % sequenceLength != 0) { 91 throw new IllegalArgumentException( 92 "Output data " + mOutputAssetsNames[i] + " size (in bytes): " + 93 outputs[i].length + " is not a multiple of sequence length: " + 94 sequenceLength); 95 } 96 } 97 if (inputs.length % mInputSizeBytes != 0) { 98 throw new IllegalArgumentException("Input data size (in bytes): " + inputs.length + 99 " is not a multiple of input size (in bytes): " + mInputSizeBytes); 100 } 101 InferenceInOutSequence sequence = new InferenceInOutSequence( 102 sequenceLength, true, mDataBytesSize); 103 104 for (int i = 0; i < sequenceLength; ++i) { 105 byte[][] outz = new byte[mOutputAssetsNames.length][]; 106 for (int j = 0; j < mOutputAssetsNames.length; ++j) { 107 int outputSizeBytes = outputs[j].length / sequenceLength; 108 outz[j] = Arrays.copyOfRange(outputs[j], outputSizeBytes * i, 109 outputSizeBytes * (i + 1)); 110 } 111 112 sequence.mInputOutputs.add(new InferenceInOut( 113 Arrays.copyOfRange(inputs, mInputSizeBytes * i, mInputSizeBytes * (i + 1)), 114 outz, 115 -1)); 116 } 117 return sequence; 118 } 119 } 120 121 /** 122 * Helper class, generates {@link InferenceInOut}[] from a directory with image files, 123 * (optional) set of labels and an image preprocessor. 124 * 125 * The images and ground truth should look like imagenet: the images in the directory 126 * must be name <prefix>-<number>.<extension>, where the number is used to find the 127 * corresponding line in the ground truth labels. 128 */ 129 public static class FromDataset { 130 private String mInputPath; 131 private String mLabelAssetName; 132 private String mGroundTruthAssetName; 133 private String mPreprocessorName; 134 private int mDatasize; 135 private float mQuantScale; 136 private float mQuantZeroPoint; 137 private int mImageDimension; 138 FromDataset(String inputPath, String labelAssetName, String groundTruthAssetName, String preprocessorName, int datasize, float quantScale, float quantZeroPoint, int imageDimension)139 public FromDataset(String inputPath, String labelAssetName, String groundTruthAssetName, 140 String preprocessorName, int datasize, 141 float quantScale, float quantZeroPoint, 142 int imageDimension) { 143 mInputPath = inputPath; 144 if (mInputPath.endsWith("/")) { 145 mInputPath = mInputPath.substring(0, mInputPath.length() - 1); 146 } 147 mLabelAssetName = labelAssetName; 148 mGroundTruthAssetName = groundTruthAssetName; 149 mPreprocessorName = preprocessorName; 150 mDatasize = datasize; 151 mQuantScale = quantScale; 152 mQuantZeroPoint = quantZeroPoint; 153 mImageDimension = imageDimension; 154 } 155 isImageFile(String fileName)156 private boolean isImageFile(String fileName) { 157 String lower = fileName.toLowerCase(); 158 return (lower.endsWith(".jpeg") || lower.endsWith(".jpg")); 159 } 160 createImageProcessor()161 private ImageProcessorInterface createImageProcessor() { 162 try { 163 Class<?> clazz = Class.forName( 164 "com.android.nn.benchmark.imageprocessors." + mPreprocessorName); 165 return (ImageProcessorInterface) clazz.getConstructor().newInstance(); 166 } catch (Exception e) { 167 throw new IllegalArgumentException( 168 "Can not create image processors named '" + mPreprocessorName + "'", 169 e); 170 } 171 } 172 getIndexFromFilename(String filename)173 private static Integer getIndexFromFilename(String filename) { 174 String index = filename.split("-")[1].split("\\.")[0]; 175 return Integer.valueOf(index, 10); 176 } 177 readDataset( final AssetManager assetManager, final File cacheDir)178 public ArrayList<InferenceInOutSequence> readDataset( 179 final AssetManager assetManager, final File cacheDir) throws IOException { 180 String[] allFileNames = assetManager.list(mInputPath); 181 ArrayList<String> imageFileNames = new ArrayList<String>(); 182 for (String fileName : allFileNames) { 183 if (isImageFile(fileName)) { 184 imageFileNames.add(fileName); 185 } 186 } 187 Collections.sort(imageFileNames, new Comparator<String>() { 188 @Override 189 public int compare(String o1, String o2) { 190 Integer index1 = getIndexFromFilename(o1); 191 Integer index2 = getIndexFromFilename(o2); 192 return index1.compareTo(index2); 193 } 194 }); 195 196 Integer[] expectedClasses = null; 197 HashMap<String, Integer> labelMap = null; 198 if (mLabelAssetName != null) { 199 labelMap = new HashMap<String, Integer>(); 200 InputStream labelStream = assetManager.open(mLabelAssetName); 201 BufferedReader labelReader = new BufferedReader( 202 new InputStreamReader(labelStream, "UTF-8")); 203 String line; 204 int index = 0; 205 while ((line = labelReader.readLine()) != null) { 206 labelMap.put(line, new Integer(index)); 207 index++; 208 } 209 } 210 if (mGroundTruthAssetName != null) { 211 expectedClasses = new Integer[imageFileNames.size()]; 212 InputStream truthStream = assetManager.open(mGroundTruthAssetName); 213 BufferedReader truthReader = new BufferedReader( 214 new InputStreamReader(truthStream, "UTF-8")); 215 String line; 216 int index = 0; 217 while ((line = truthReader.readLine()) != null) { 218 if (labelMap != null) { 219 expectedClasses[index] = labelMap.get(line); 220 } else { 221 expectedClasses[index] = Integer.parseInt(line, 10); 222 } 223 index++; 224 } 225 } 226 227 ArrayList<InferenceInOutSequence> ret = new ArrayList<InferenceInOutSequence>(); 228 final ImageProcessorInterface imageProcessor = createImageProcessor(); 229 230 for (int i = 0; i < imageFileNames.size(); i++) { 231 final String fileName = mInputPath + '/' + imageFileNames.get(i); 232 int expectedClass = -1; 233 if (expectedClasses != null) { 234 expectedClass = expectedClasses[i]; 235 } 236 InferenceInOut.InputCreatorInterface creator = 237 new InferenceInOut.InputCreatorInterface() { 238 @Override 239 public void createInput(ByteBuffer buffer) { 240 try { 241 imageProcessor.preprocess(mDatasize, 242 mQuantScale, mQuantZeroPoint, mImageDimension, 243 assetManager, fileName, cacheDir, buffer); 244 } catch (Throwable t) { 245 throw new Error("Failed to create image input", t); 246 } 247 } 248 }; 249 InferenceInOutSequence sequence = new InferenceInOutSequence( 250 1, false, mDatasize); 251 sequence.mInputOutputs.add(new InferenceInOut(creator, null, 252 expectedClass)); 253 ret.add(sequence); 254 } 255 return ret; 256 } 257 } 258 } 259